From 6ec4f6237970a9cbf4cb8defd42ab8620b6d7c92 Mon Sep 17 00:00:00 2001 From: "Tristan F." Date: Fri, 10 Oct 2025 06:32:06 +0000 Subject: [PATCH 01/25] refactor: separate statistic computation we also make it lazy --- spras/analysis/summary.py | 44 +++----------------- spras/statistics.py | 88 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 38 deletions(-) create mode 100644 spras/statistics.py diff --git a/spras/analysis/summary.py b/spras/analysis/summary.py index c8abc1cad..fd70db8f3 100644 --- a/spras/analysis/summary.py +++ b/spras/analysis/summary.py @@ -1,10 +1,11 @@ from pathlib import Path -from statistics import median from typing import Iterable import networkx as nx import pandas as pd +from spras.statistics import compute_statistics, statistics_options + def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, algo_params: dict[str, dict], algo_with_params: list) -> pd.DataFrame: @@ -47,44 +48,11 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg # Save the network name, number of nodes, number edges, and number of connected components nw_name = str(file_path) - number_nodes = nw.number_of_nodes() - number_edges = nw.number_of_edges() - ncc = nx.number_connected_components(nw) - - # Save the max/median degree, average clustering coefficient, and density - if number_nodes == 0: - max_degree = 0 - median_degree = 0.0 - density = 0.0 - else: - degrees = [deg for _, deg in nw.degree()] - max_degree = max(degrees) - median_degree = median(degrees) - density = nx.density(nw) - - cc = list(nx.connected_components(nw)) - # Save the max diameter - # Use diameter only for components with ≥2 nodes (singleton components have diameter 0) - diameters = [ - nx.diameter(nw.subgraph(c).copy()) if len(c) > 1 else 0 - for c in cc - ] - max_diameter = max(diameters, default=0) - - # Save the average path lengths - # Compute average shortest path length only for components with ≥2 nodes 
"""
Graph statistics, used to power summary.py.

We allow for arbitrary computation of any specific statistic on some graph,
computing more than necessary if we have dependencies. See the top level
`statistics_computation` dictionary for usage.
"""

import itertools
from statistics import median
from typing import Callable

import networkx as nx


def compute_degree(graph: nx.Graph) -> tuple[int, float]:
    """
    Compute the (max degree, median degree) of `graph`.

    Returns (0, 0.0) for the empty graph, where both are undefined.
    """
    # number_of_nodes is a cheap call, so this emptiness guard is free.
    if graph.number_of_nodes() == 0:
        return (0, 0.0)
    degrees = [deg for _, deg in graph.degree()]
    return max(degrees), median(degrees)


def compute_on_cc(graph: nx.Graph) -> tuple[int, float]:
    """
    Compute the (max diameter, average path length) over the connected
    components of `graph`.

    NOTE: `nx.connected_components` requires an undirected graph, so the
    annotation here is `nx.Graph` rather than `nx.DiGraph`.
    """
    components = list(nx.connected_components(graph))

    # Use diameter only for components with >=2 nodes
    # (singleton components have diameter 0).
    diameters = [
        nx.diameter(graph.subgraph(component).copy()) if len(component) > 1 else 0
        for component in components
    ]
    max_diameter = max(diameters, default=0)

    # Compute average shortest path length only for components with >=2 nodes
    # (undefined for singletons, set to 0.0).
    avg_path_lengths = [
        nx.average_shortest_path_length(graph.subgraph(component).copy()) if len(component) > 1 else 0.0
        for component in components
    ]
    avg_path_len = sum(avg_path_lengths) / len(avg_path_lengths) if avg_path_lengths else 0.0

    return max_diameter, avg_path_len


# Maps a tuple of statistic names to a function computing exactly those
# statistics, in order, on a graph. Statistics that share expensive
# intermediate work (degree lists, connected components) are grouped so the
# work is done once.
statistics_computation: dict[tuple[str, ...], Callable[[nx.Graph], tuple[float | int, ...]]] = {
    ('Number of nodes',): lambda graph: (graph.number_of_nodes(),),
    ('Number of edges',): lambda graph: (graph.number_of_edges(),),
    ('Number of connected components',): lambda graph: (nx.number_connected_components(graph),),
    ('Density',): lambda graph: (nx.density(graph),),

    ('Max degree', 'Median degree'): compute_degree,
    ('Max diameter', 'Average path length'): compute_on_cc,
}

# All of the keys inside statistics_computation, flattened.
statistics_options: list[str] = list(itertools.chain.from_iterable(statistics_computation.keys()))


def compute_statistics(graph: nx.Graph, statistics: list[str]) -> dict[str, float | int]:
    """
    Compute the requested `statistics` for `graph`.

    `statistics` must be a subset of `statistics_options`. Grouped statistics
    are computed together, so requesting one member of a group costs the same
    as requesting all of them.

    Raises RuntimeError if an unknown statistic name is requested.
    """
    # Early validation scan: we want to err as soon as possible.
    for stat in statistics:
        if stat not in statistics_options:
            raise RuntimeError(f"Statistic {stat} is not a computable statistic! Available statistics: {statistics_options}")

    requested = set(statistics)
    computed_statistics: dict[str, float | int] = {}
    for statistic_tuple, compute in statistics_computation.items():
        # Only run a computation group when it produces something we want.
        if requested.isdisjoint(statistic_tuple):
            continue
        computed_tuple = compute(graph)
        # strict=True guards against a computation returning the wrong arity.
        computed_statistics.update(zip(statistic_tuple, computed_tuple, strict=True))

    # Return only the statistics that were asked for, in request order.
    return {key: computed_statistics[key] for key in statistics}
Date: Fri, 10 Oct 2025 07:06:46 +0000 Subject: [PATCH 03/25] fix: stably use graph statistic values --- spras/analysis/summary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spras/analysis/summary.py b/spras/analysis/summary.py index fd70db8f3..432dba0a4 100644 --- a/spras/analysis/summary.py +++ b/spras/analysis/summary.py @@ -52,7 +52,7 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg graph_statistics = compute_statistics(nw, statistics_options) # Initialize list to store current network information - cur_nw_info = [nw_name, *graph_statistics] + cur_nw_info = [nw_name, *graph_statistics.values()] # Iterate through each node property and save the intersection with the current network for node_list in nodes_by_col: From cb373c130760c7040b16ec03ba1d2673e343465b Mon Sep 17 00:00:00 2001 From: "Tristan F." Date: Wed, 29 Oct 2025 17:56:22 -0700 Subject: [PATCH 04/25] style: fmt --- spras/config/config.py | 4 ++-- spras/statistics.py | 12 +++++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/spras/config/config.py b/spras/config/config.py index 22e655941..add815d9d 100644 --- a/spras/config/config.py +++ b/spras/config/config.py @@ -71,7 +71,7 @@ def __init__(self, raw_config: dict[str, Any]): self.container_prefix: str = DEFAULT_CONTAINER_PREFIX # A Boolean specifying whether to unpack singularity containers. 
Default is False self.unpack_singularity = False - # A Boolean indiciating whether to enable container runtime profiling (apptainer/singularity only) + # A Boolean indicating whether to enable container runtime profiling (apptainer/singularity only) self.enable_profiling = False # A dictionary to store configured datasets against which SPRAS will be run self.datasets = None @@ -308,7 +308,7 @@ def process_config(self, raw_config: RawConfig): if raw_config.container_registry and raw_config.container_registry.base_url != "" and raw_config.container_registry.owner != "": self.container_prefix = raw_config.container_registry.base_url + "/" + raw_config.container_registry.owner - if raw_config.enable_profiling and not raw_config.container_framework in ["singularity", "apptainer"]: + if raw_config.enable_profiling and raw_config.container_framework not in ["singularity", "apptainer"]: warnings.warn("enable_profiling is set to true, but the container framework is not singularity/apptainer. This setting will have no effect.") self.enable_profiling = raw_config.enable_profiling diff --git a/spras/statistics.py b/spras/statistics.py index ac91b80a9..49ae8b3fc 100644 --- a/spras/statistics.py +++ b/spras/statistics.py @@ -7,10 +7,12 @@ """ import itertools -import networkx as nx from statistics import median from typing import Callable +import networkx as nx + + def compute_degree(graph: nx.DiGraph) -> tuple[int, float]: """ Computes the (max, median) degree of a `graph`. @@ -43,7 +45,7 @@ def compute_on_cc(graph: nx.DiGraph) -> tuple[int, float]: avg_path_len = sum(avg_path_lengths) / len(avg_path_lengths) else: avg_path_len = 0.0 - + return max_diameter, avg_path_len # The type signature on here is quite bad. I would like to say that an n-tuple has n-outputs. 
@@ -52,7 +54,7 @@ def compute_on_cc(graph: nx.DiGraph) -> tuple[int, float]: ('Number of edges',): lambda graph : (graph.number_of_edges(),), ('Number of connected components',): lambda graph : (nx.number_connected_components(graph),), ('Density',): lambda graph : (nx.density(graph),), - + ('Max degree', 'Median degree'): compute_degree, ('Max diameter', 'Average path length'): compute_on_cc, } @@ -63,7 +65,7 @@ def compute_on_cc(graph: nx.DiGraph) -> tuple[int, float]: def compute_statistics(graph: nx.DiGraph, statistics: list[str]) -> dict[str, float | int]: """ Computes `statistics` for a graph corresponding to the top-level `statistics` dictionary - in this file. + in this file. """ # early-scan cutoff for statistics: @@ -71,7 +73,7 @@ def compute_statistics(graph: nx.DiGraph, statistics: list[str]) -> dict[str, fl for stat in statistics: if stat not in statistics_options: raise RuntimeError(f"Statistic {stat} not a computable statistics! Available statistics: {statistics_options}") - + # now, we can compute statistics only computed_statistics: dict[str, float | int] = dict() for statistic_tuple, compute in statistics_computation.items(): From 4640bc0c3fcf57f2427d2bb0200381d7ce8ad6cb Mon Sep 17 00:00:00 2001 From: "Tristan F." 
Date: Thu, 30 Oct 2025 01:13:01 +0000 Subject: [PATCH 05/25] feat: init intervals and heuristics --- spras/config/schema.py | 50 ++++++++++ spras/interval.py | 203 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 253 insertions(+) create mode 100644 spras/interval.py diff --git a/spras/config/schema.py b/spras/config/schema.py index 2b46aaf1c..fbacb8793 100644 --- a/spras/config/schema.py +++ b/spras/config/schema.py @@ -13,9 +13,12 @@ import re from typing import Annotated, Optional +import networkx as nx from pydantic import AfterValidator, BaseModel, ConfigDict, Field from spras.config.util import CaseInsensitiveEnum +from spras.interval import Interval +from spras.statistics import compute_statistics, statistics_options # Most options here have an `include` property, # which is meant to make disabling parts of the configuration easier. @@ -148,6 +151,51 @@ class ReconstructionSettings(BaseModel): model_config = ConfigDict(extra='forbid') +class GraphHeuristics(BaseModel): + number_of_nodes: list[Interval] = [] + number_of_edges: list[Interval] = [] + number_of_connected_components: list[Interval] = [] + density: list[Interval] = [] + + max_degree: list[Interval] = [] + median_degree: list[Interval] = [] + max_diameter: list[Interval] = [] + average_path_length: list[Interval] = [] + + def validate_graph(self, graph: nx.DiGraph): + statistics_dictionary = { + 'Number of nodes': self.number_of_nodes, + 'Number of edges': self.number_of_edges, + 'Number of connected components': self.number_of_connected_components, + 'Density': self.density, + 'Max degree': self.max_degree, + 'Median degree': self.median_degree, + 'Max diameter': self.max_diameter, + 'Average path length': self.average_path_length + } + + # quick assert: is statistics_dictionary exhaustive? 
+ assert set(statistics_dictionary.keys()) == set(statistics_options) + + stats = compute_statistics( + graph, + list(k for k, v in statistics_dictionary.items() if len(v) == 0) + ) + + for key, value in stats.items(): + intervals = statistics_dictionary[key] + + matches_heuristics = False + for interval in intervals: + if interval.mem(value): + matches_heuristics = True + break + + if not matches_heuristics: + raise RuntimeError(f"Heuristic {key} with value {value} does not match {intervals}!") + + model_config = ConfigDict(extra='forbid') + class RawConfig(BaseModel): # TODO: move these container values to a nested container key container_framework: ContainerFramework = ContainerFramework.docker @@ -165,6 +213,8 @@ class RawConfig(BaseModel): reconstruction_settings: ReconstructionSettings + heuristics: GraphHeuristics = GraphHeuristics() + # We include use_attribute_docstrings here to preserve the docstrings # after attributes at runtime (for future JSON schema generation) model_config = ConfigDict(extra='forbid', use_attribute_docstrings=True) diff --git a/spras/interval.py b/spras/interval.py new file mode 100644 index 000000000..6771a228b --- /dev/null +++ b/spras/interval.py @@ -0,0 +1,203 @@ +""" +Utilities for defining inequality intervals (e.g. l < x <= u) + +For graph heuristics, we allow inequality intervals of the form (num) < (id)?. For example, +we can say "1500 <" for "1500 < x", or "1000 < x < 2000", etc. + +[If there is ever a library that does this, we should replace this code with that library.] +""" + +from enum import Enum +from typing import Any, Optional, Self +from pydantic import BaseModel, model_validator + +class Operand(Enum): + LT = "<" + LTE = "<=" + EQ = "=" + GTE = ">=" + GT = ">" + + @classmethod + def from_str(cls, string: str) -> Optional[Self]: + return next((enum for enum in list(cls) if enum.value == string), None) + + def is_closed(self) -> bool: + """Whether this is a closed inequality. 
We consider = to be closed.""" + match self: + case Operand.LTE: return True + case Operand.EQ: return True + case Operand.GT: return True + return False + + def as_closed(self): + """Closes an operand. Eq does not get modified.""" + match self: + case Operand.LT: return Operand.LTE + case Operand.GT: return Operand.GTE + return self + + def as_opened(self): + """Opens an operand. Eq does not get modified.""" + match self: + case Operand.LTE: return Operand.LT + case Operand.GTE: return Operand.GT + return self + + def with_closed(self, closed: bool): return self.as_closed() if closed else self.as_opened() + + def compare(self, left, right) -> bool: + match self: + case Operand.LT: return left < right + case Operand.LTE: return left <= right + case Operand.EQ: return left == right + case Operand.GTE: return left >= right + case Operand.GT: return left > right + + @classmethod + def combine(cls, left: Self, right: Self): + """Combines two operands, returning None if the operands don't combine well.""" + match (left, right): + case (Operand.LTE, Operand.LTE): return Operand.LTE + case (Operand.LT, Operand.LTE): return Operand.LT + case (Operand.LT, Operand.LT): return Operand.LT + case (Operand.EQ, op): return op + case (op, Operand.EQ): return op + case (Operand.GTE, Operand.GTE): return Operand.GTE + case (Operand.GT, Operand.GTE): return Operand.GT + case (Operand.GT, Operand.GT): return Operand.GT + return None + +class Interval(BaseModel): + lower: Optional[float] + upper: Optional[float] + lower_closed: bool + upper_closed: bool + + def mem(self, num: float) -> bool: + if self.lower is not None: + meets_lower = self.lower <= num if self.lower_closed else self.lower < num + else: + meets_lower = True + + if self.upper is not None: + meets_upper = num <= self.upper if self.upper_closed else num < self.upper + else: + meets_upper = True + + return meets_lower and meets_upper + + @classmethod + def single(cls, num: float) -> Self: + return cls(lower=num, upper=num, 
lower_closed=True, upper_closed=True) + + @classmethod + def left_operand(cls, operand: Operand, num: float) -> Self: + """Creates an interval whose operand is on the left (e.g. <300)""" + match operand: + case Operand.LT: return cls(lower=None, upper=num, lower_closed=False, upper_closed=False) + case Operand.LTE: return cls(lower=None, upper=num, lower_closed=True, upper_closed=False) + case Operand.EQ: return cls.single(num) + case Operand.GTE: return cls(lower=num, upper=None, lower_closed=False, upper_closed=False) + case Operand.GT: return cls(lower=num, upper=None, lower_closed=False, upper_closed=True) + + @classmethod + def right_operand(cls, num: float, operand: Operand) -> Self: + """Creates an interval whose operand is on the right (e.g. 300<)""" + match operand: + case Operand.LT: return cls(lower=num, upper=None, lower_closed=False, upper_closed=False) + case Operand.LTE: return cls(lower=num, upper=None, lower_closed=True, upper_closed=False) + case Operand.EQ: return cls.single(num) + case Operand.GTE: return cls(lower=None, upper=num, lower_closed=False, upper_closed=False) + case Operand.GT: return cls(lower=None, upper=num, lower_closed=False, upper_closed=True) + + @classmethod + def from_string(cls, input: str) -> Self: + tokens = [token.strip() for token in input.split(" ")] + + assert len(tokens) != 0 + + def parse_num(numstr: str) -> Optional[int]: + # Allow pythonic separators + try: + return int(numstr.replace("_", "")) + except: + return None + + def is_id(idstr: str) -> bool: return idstr.isidentifier() + + # Case 1: (id?) operand number + if is_id(tokens[0]): + # No other cases have an id at the beginning: we get rid of it. + tokens.pop() + + operand = Operand.from_str(tokens[0]) + if operand is not None: + # (cont.) Case 1: (id?) operand number + number = parse_num(tokens[1]) + assert number is not None, f"found operand {operand.value} and expected a number, but found {tokens[1]} instead." 
+ return cls.left_operand(operand, number) + + # All other cases have a number + number = parse_num(tokens.pop()) + assert number is not None, f"expected an inequality, got {input} instead" + + # Case 2: number + if len(tokens) == 0: + return cls.single(number) + + # All other cases have an operand + operand = Operand.from_str(tokens.pop()) + assert operand is not None, f"got {number}, expected an operand afterward." + + # Case 3: number operand (id?) + if len(tokens) == 0 or len(tokens) == 1: + if len(tokens) == 1: assert is_id(tokens[1]) + return cls.right_operand(number, operand) + + # Case 4: number operand id operand number + id = tokens.pop() + assert is_id(id), f"got an inequality of the form {number} {operand.value} and expected nothing or another identifier, but got {id} instead." + + second_operand_str = tokens.pop() + second_operand = Operand.from_str(second_operand_str) + assert second_operand is not None, f"got an inequality of the form {number} {operand.value} {id} and was expecting an operand, but got {second_operand_str} instead." + + second_number_str = tokens.pop() + second_number = parse_num(second_number_str) + assert second_number is not None, f"got an inequality of the form {number} {operand.value} {id} {second_operand.value} and was expecting a number, but got {second_number_str} instead." + + # don't want equals operands in a double inequality (a < b < c) + assert operand is not Operand.EQ and second_operand is not Operand.EQ, f"in a double inequality, neither operand can be '='!" + + # are our two numbers valid? + combined_operand = Operand.combine(operand, second_operand) + assert combined_operand is not None, f"operands {operand.value} and {second_operand} must combine well with each other!" + assert combined_operand.compare(number, second_number), f"{number} {operand.value} {second_number} does not hold!" 
+ + return cls( + lower=number, + upper=second_number, + lower_closed=operand.is_closed(), + upper_closed=second_operand.is_closed() + ) + + def __str__(self) -> str: + if not self.lower and not self.upper: return "{empty interval}" + if not self.lower: + return Operand.LT.with_closed(self.upper_closed).value + " " + str(self.upper) + if not self.upper: + return str(self.lower) + " " + Operand.LT.with_closed(self.lower_closed).value + + if self.lower == self.upper and self.lower_closed and self.upper_closed: return str(self.lower) + + return str(self.lower) + " " + Operand.LT.with_closed(self.lower_closed).value + " " + "x" \ + + Operand.LT.with_closed(self.upper_closed).value + str(self.upper) + + # For parsing Intervals automatically with pydantic. + @model_validator(mode="before") + @classmethod + def from_literal(cls, data: Any) -> Any: + if isinstance(data, str): + return cls.from_string(data) + return data From 898d568a49053467d74af1cb952bdceac400436d Mon Sep 17 00:00:00 2001 From: "Tristan F." Date: Wed, 29 Oct 2025 18:15:23 -0700 Subject: [PATCH 06/25] style: specify zip strict --- spras/statistics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spras/statistics.py b/spras/statistics.py index 49ae8b3fc..1ebe7cc62 100644 --- a/spras/statistics.py +++ b/spras/statistics.py @@ -82,7 +82,7 @@ def compute_statistics(graph: nx.DiGraph, statistics: list[str]) -> dict[str, fl computed_tuple = compute(graph) assert len(statistic_tuple) == len(computed_tuple), f"bad tuple length for {statistic_tuple}" - current_computed_statistics = zip(statistic_tuple, computed_tuple) + current_computed_statistics = zip(statistic_tuple, computed_tuple, strict=True) for stat, value in current_computed_statistics: computed_statistics[stat] = value From 8177ed665b13714e56991eaef8ee45e96c1f0dbb Mon Sep 17 00:00:00 2001 From: "Tristan F." 
import networkx as nx
from pydantic import BaseModel, ConfigDict

from spras.interval import Interval
from spras.statistics import compute_statistics, statistics_options


class GraphHeuristicError(RuntimeError):
    """
    Represents an error arising from a graph algorithm output
    not meeting the necessary graph heuristics.
    """
    # Each entry is (statistic name, computed value, intervals that all rejected it).
    failed_heuristics: list[tuple[str, float | int, list[Interval]]]

    @staticmethod
    def to_string(failed_heuristics: list[tuple[str, float | int, list[Interval]]]) -> str:
        return f"The following heuristics failed: {failed_heuristics}"

    def __init__(self, failed_heuristics: list[tuple[str, float | int, list[Interval]]]):
        super().__init__(GraphHeuristicError.to_string(failed_heuristics))
        self.failed_heuristics = failed_heuristics

    def __str__(self) -> str:
        return GraphHeuristicError.to_string(self.failed_heuristics)


class GraphHeuristics(BaseModel):
    """
    Per-statistic interval constraints on an output graph.

    Each field takes a single interval or a list of intervals; an empty list
    (the default) leaves that statistic unconstrained. A graph satisfies a
    constrained statistic when its value falls inside at least one interval.
    """
    number_of_nodes: Interval | list[Interval] = []
    number_of_edges: Interval | list[Interval] = []
    number_of_connected_components: Interval | list[Interval] = []
    density: Interval | list[Interval] = []

    max_degree: Interval | list[Interval] = []
    median_degree: Interval | list[Interval] = []
    max_diameter: Interval | list[Interval] = []
    average_path_length: Interval | list[Interval] = []

    def validate_graph(self, graph: nx.DiGraph):
        """
        Check `graph` against every configured heuristic.

        Raises GraphHeuristicError listing each statistic whose computed value
        is contained in none of its configured intervals.
        """
        statistics_dictionary = {
            'Number of nodes': self.number_of_nodes,
            'Number of edges': self.number_of_edges,
            'Number of connected components': self.number_of_connected_components,
            'Density': self.density,
            'Max degree': self.max_degree,
            'Median degree': self.median_degree,
            'Max diameter': self.max_diameter,
            'Average path length': self.average_path_length
        }

        # quick assert: is statistics_dictionary exhaustive?
        assert set(statistics_dictionary.keys()) == set(statistics_options)

        # Normalize: a bare Interval means a one-element list of intervals.
        constraints: dict[str, list[Interval]] = {
            key: value if isinstance(value, list) else [value]
            for key, value in statistics_dictionary.items()
        }

        # Only compute the statistics that actually have constraints attached;
        # unconstrained statistics (empty lists) are skipped entirely.
        stats = compute_statistics(
            graph,
            [key for key, intervals in constraints.items() if intervals]
        )

        failed_heuristics: list[tuple[str, float | int, list[Interval]]] = []
        for key, value in stats.items():
            intervals = constraints[key]
            # The value must fall inside at least one configured interval.
            if not any(interval.mem(value) for interval in intervals):
                failed_heuristics.append((key, value, intervals))

        if len(failed_heuristics) != 0:
            raise GraphHeuristicError(failed_heuristics)

    model_config = ConfigDict(extra='forbid')
@@ -140,51 +138,6 @@ class ReconstructionSettings(BaseModel): model_config = ConfigDict(extra='forbid') -class GraphHeuristics(BaseModel): - number_of_nodes: list[Interval] = [] - number_of_edges: list[Interval] = [] - number_of_connected_components: list[Interval] = [] - density: list[Interval] = [] - - max_degree: list[Interval] = [] - median_degree: list[Interval] = [] - max_diameter: list[Interval] = [] - average_path_length: list[Interval] = [] - - def validate_graph(self, graph: nx.DiGraph): - statistics_dictionary = { - 'Number of nodes': self.number_of_nodes, - 'Number of edges': self.number_of_edges, - 'Number of connected components': self.number_of_connected_components, - 'Density': self.density, - 'Max degree': self.max_degree, - 'Median degree': self.median_degree, - 'Max diameter': self.max_diameter, - 'Average path length': self.average_path_length - } - - # quick assert: is statistics_dictionary exhaustive? - assert set(statistics_dictionary.keys()) == set(statistics_options) - - stats = compute_statistics( - graph, - list(k for k, v in statistics_dictionary.items() if len(v) == 0) - ) - - for key, value in stats.items(): - intervals = statistics_dictionary[key] - - matches_heuristics = False - for interval in intervals: - if interval.mem(value): - matches_heuristics = True - break - - if not matches_heuristics: - raise RuntimeError(f"Heuristic {key} with value {value} does not match {intervals}!") - - model_config = ConfigDict(extra='forbid') - class RawConfig(BaseModel): containers: ContainerSettings enable_profiling: bool = False From fac110847104dd77fa46bcb3d67bd8af4d4df6c7 Mon Sep 17 00:00:00 2001 From: "Tristan F." 
Date: Thu, 30 Oct 2025 06:34:20 +0000 Subject: [PATCH 08/25] fix: proper tokenization --- spras/config/heuristics.py | 4 +++- spras/interval.py | 30 ++++++++++++++++++------------ 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/spras/config/heuristics.py b/spras/config/heuristics.py index 3f42c6aba..99ecb4695 100644 --- a/spras/config/heuristics.py +++ b/spras/config/heuristics.py @@ -1,8 +1,10 @@ import networkx as nx from pydantic import BaseModel, ConfigDict + from spras.interval import Interval from spras.statistics import compute_statistics, statistics_options + class GraphHeuristicError(RuntimeError): """ Represents an error arising from a graph algorithm output @@ -18,7 +20,7 @@ def __init__(self, failed_heuristics: list[tuple[str, float | int, list[Interval super().__init__(GraphHeuristicError.to_string(failed_heuristics)) self.failed_heuristics = failed_heuristics - + def __str__(self) -> str: return GraphHeuristicError.to_string(self.failed_heuristics) diff --git a/spras/interval.py b/spras/interval.py index 6771a228b..6b304075d 100644 --- a/spras/interval.py +++ b/spras/interval.py @@ -7,10 +7,14 @@ [If there is ever a library that does this, we should replace this code with that library.] """ +import tokenize from enum import Enum +from io import BytesIO from typing import Any, Optional, Self + from pydantic import BaseModel, model_validator + class Operand(Enum): LT = "<" LTE = "<=" @@ -21,7 +25,7 @@ class Operand(Enum): @classmethod def from_str(cls, string: str) -> Optional[Self]: return next((enum for enum in list(cls) if enum.value == string), None) - + def is_closed(self) -> bool: """Whether this is a closed inequality. 
We consider = to be closed.""" match self: @@ -43,7 +47,7 @@ def as_opened(self): case Operand.LTE: return Operand.LT case Operand.GTE: return Operand.GT return self - + def with_closed(self, closed: bool): return self.as_closed() if closed else self.as_opened() def compare(self, left, right) -> bool: @@ -53,7 +57,7 @@ def compare(self, left, right) -> bool: case Operand.EQ: return left == right case Operand.GTE: return left >= right case Operand.GT: return left > right - + @classmethod def combine(cls, left: Self, right: Self): """Combines two operands, returning None if the operands don't combine well.""" @@ -79,12 +83,12 @@ def mem(self, num: float) -> bool: meets_lower = self.lower <= num if self.lower_closed else self.lower < num else: meets_lower = True - + if self.upper is not None: meets_upper = num <= self.upper if self.upper_closed else num < self.upper else: meets_upper = True - + return meets_lower and meets_upper @classmethod @@ -113,8 +117,10 @@ def right_operand(cls, num: float, operand: Operand) -> Self: @classmethod def from_string(cls, input: str) -> Self: - tokens = [token.strip() for token in input.split(" ")] - + # We can't do a normal string#split here for cases like "1500<" + tokens = [t.string for t in tokenize.tokenize(BytesIO(input.encode('utf-8')).readline) if t.string != ""] + tokens.pop() # drop utf-8 indicator + assert len(tokens) != 0 def parse_num(numstr: str) -> Optional[int]: @@ -130,18 +136,18 @@ def is_id(idstr: str) -> bool: return idstr.isidentifier() if is_id(tokens[0]): # No other cases have an id at the beginning: we get rid of it. tokens.pop() - + operand = Operand.from_str(tokens[0]) if operand is not None: # (cont.) Case 1: (id?) operand number number = parse_num(tokens[1]) assert number is not None, f"found operand {operand.value} and expected a number, but found {tokens[1]} instead." 
return cls.left_operand(operand, number) - + # All other cases have a number number = parse_num(tokens.pop()) assert number is not None, f"expected an inequality, got {input} instead" - + # Case 2: number if len(tokens) == 0: return cls.single(number) @@ -158,7 +164,7 @@ def is_id(idstr: str) -> bool: return idstr.isidentifier() # Case 4: number operand id operand number id = tokens.pop() assert is_id(id), f"got an inequality of the form {number} {operand.value} and expected nothing or another identifier, but got {id} instead." - + second_operand_str = tokens.pop() second_operand = Operand.from_str(second_operand_str) assert second_operand is not None, f"got an inequality of the form {number} {operand.value} {id} and was expecting an operand, but got {second_operand_str} instead." @@ -174,7 +180,7 @@ def is_id(idstr: str) -> bool: return idstr.isidentifier() combined_operand = Operand.combine(operand, second_operand) assert combined_operand is not None, f"operands {operand.value} and {second_operand} must combine well with each other!" assert combined_operand.compare(number, second_number), f"{number} {operand.value} {second_number} does not hold!" - + return cls( lower=number, upper=second_number, From 2e0d8d0faec96ffd5fcfb49496645f9464e3f24c Mon Sep 17 00:00:00 2001 From: "Tristan F." 
Date: Thu, 30 Oct 2025 07:13:24 +0000 Subject: [PATCH 09/25] fix(interval): correct parsing --- spras/interval.py | 16 ++++++++-------- test/test_interval.py | 6 ++++++ 2 files changed, 14 insertions(+), 8 deletions(-) create mode 100644 test/test_interval.py diff --git a/spras/interval.py b/spras/interval.py index 6b304075d..52e4ab802 100644 --- a/spras/interval.py +++ b/spras/interval.py @@ -119,7 +119,7 @@ def right_operand(cls, num: float, operand: Operand) -> Self: def from_string(cls, input: str) -> Self: # We can't do a normal string#split here for cases like "1500<" tokens = [t.string for t in tokenize.tokenize(BytesIO(input.encode('utf-8')).readline) if t.string != ""] - tokens.pop() # drop utf-8 indicator + tokens.pop(0) # drop utf-8 indicator assert len(tokens) != 0 @@ -135,7 +135,7 @@ def is_id(idstr: str) -> bool: return idstr.isidentifier() # Case 1: (id?) operand number if is_id(tokens[0]): # No other cases have an id at the beginning: we get rid of it. - tokens.pop() + tokens.pop(0) operand = Operand.from_str(tokens[0]) if operand is not None: @@ -145,15 +145,15 @@ def is_id(idstr: str) -> bool: return idstr.isidentifier() return cls.left_operand(operand, number) # All other cases have a number - number = parse_num(tokens.pop()) - assert number is not None, f"expected an inequality, got {input} instead" + number = parse_num(tokens.pop(0)) + assert number is not None, f"expected a number, got {input} instead" # Case 2: number if len(tokens) == 0: return cls.single(number) # All other cases have an operand - operand = Operand.from_str(tokens.pop()) + operand = Operand.from_str(tokens.pop(0)) assert operand is not None, f"got {number}, expected an operand afterward." # Case 3: number operand (id?) 
@@ -162,14 +162,14 @@ def is_id(idstr: str) -> bool: return idstr.isidentifier() return cls.right_operand(number, operand) # Case 4: number operand id operand number - id = tokens.pop() + id = tokens.pop(0) assert is_id(id), f"got an inequality of the form {number} {operand.value} and expected nothing or another identifier, but got {id} instead." - second_operand_str = tokens.pop() + second_operand_str = tokens.pop(0) second_operand = Operand.from_str(second_operand_str) assert second_operand is not None, f"got an inequality of the form {number} {operand.value} {id} and was expecting an operand, but got {second_operand_str} instead." - second_number_str = tokens.pop() + second_number_str = tokens.pop(0) second_number = parse_num(second_number_str) assert second_number is not None, f"got an inequality of the form {number} {operand.value} {id} {second_operand.value} and was expecting a number, but got {second_number_str} instead." diff --git a/test/test_interval.py b/test/test_interval.py new file mode 100644 index 000000000..e263817bc --- /dev/null +++ b/test/test_interval.py @@ -0,0 +1,6 @@ +from spras.interval import Interval + +class TestInterval: + def test_number(self): + assert Interval.single(5) == Interval(lower=5, upper=5, lower_closed=True, upper_closed=True) + assert Interval.from_string("5") == Interval.single(5) From 183c3ad874fad9096ced9aa4cc4c6ccd88769687 Mon Sep 17 00:00:00 2001 From: "Tristan F." 
Date: Thu, 30 Oct 2025 07:30:19 +0000 Subject: [PATCH 10/25] fix(interval): correct other parsing mistakes --- spras/interval.py | 53 ++++++++++++++++++++++++++++--------------- test/test_interval.py | 11 +++++++++ 2 files changed, 46 insertions(+), 18 deletions(-) diff --git a/spras/interval.py b/spras/interval.py index 52e4ab802..7e5f6c663 100644 --- a/spras/interval.py +++ b/spras/interval.py @@ -10,7 +10,7 @@ import tokenize from enum import Enum from io import BytesIO -from typing import Any, Optional, Self +from typing import Any, Optional, Self, cast from pydantic import BaseModel, model_validator @@ -31,7 +31,7 @@ def is_closed(self) -> bool: match self: case Operand.LTE: return True case Operand.EQ: return True - case Operand.GT: return True + case Operand.GTE: return True return False def as_closed(self): @@ -57,6 +57,14 @@ def compare(self, left, right) -> bool: case Operand.EQ: return left == right case Operand.GTE: return left >= right case Operand.GT: return left > right + + def flip(self): + match self: + case Operand.LT: return Operand.GT + case Operand.LTE: return Operand.GTE + case Operand.EQ: return Operand.EQ + case Operand.GTE: return Operand.LTE + case Operand.GT: return Operand.LT @classmethod def combine(cls, left: Self, right: Self): @@ -64,11 +72,13 @@ def combine(cls, left: Self, right: Self): match (left, right): case (Operand.LTE, Operand.LTE): return Operand.LTE case (Operand.LT, Operand.LTE): return Operand.LT + case (Operand.LTE, Operand.LT): return Operand.LT case (Operand.LT, Operand.LT): return Operand.LT case (Operand.EQ, op): return op case (op, Operand.EQ): return op case (Operand.GTE, Operand.GTE): return Operand.GTE case (Operand.GT, Operand.GTE): return Operand.GT + case (Operand.GTE, Operand.GT): return Operand.GT case (Operand.GT, Operand.GT): return Operand.GT return None @@ -108,12 +118,8 @@ def left_operand(cls, operand: Operand, num: float) -> Self: @classmethod def right_operand(cls, num: float, operand: Operand) -> 
Self: """Creates an interval whose operand is on the right (e.g. 300<)""" - match operand: - case Operand.LT: return cls(lower=num, upper=None, lower_closed=False, upper_closed=False) - case Operand.LTE: return cls(lower=num, upper=None, lower_closed=True, upper_closed=False) - case Operand.EQ: return cls.single(num) - case Operand.GTE: return cls(lower=None, upper=num, lower_closed=False, upper_closed=False) - case Operand.GT: return cls(lower=None, upper=num, lower_closed=False, upper_closed=True) + # TODO: remove cast? + return cast(Self, Interval.left_operand(operand.flip(), num)) @classmethod def from_string(cls, input: str) -> Self: @@ -158,7 +164,7 @@ def is_id(idstr: str) -> bool: return idstr.isidentifier() # Case 3: number operand (id?) if len(tokens) == 0 or len(tokens) == 1: - if len(tokens) == 1: assert is_id(tokens[1]) + if len(tokens) == 1: assert is_id(tokens[0]) return cls.right_operand(number, operand) # Case 4: number operand id operand number @@ -178,15 +184,23 @@ def is_id(idstr: str) -> bool: return idstr.isidentifier() # are our two numbers valid? combined_operand = Operand.combine(operand, second_operand) - assert combined_operand is not None, f"operands {operand.value} and {second_operand} must combine well with each other!" + assert combined_operand is not None, f"operands {operand.value} and {second_operand.value} must combine well with each other!" assert combined_operand.compare(number, second_number), f"{number} {operand.value} {second_number} does not hold!" 
- return cls( - lower=number, - upper=second_number, - lower_closed=operand.is_closed(), - upper_closed=second_operand.is_closed() - ) + if combined_operand.as_opened() == Operand.LT: + return cls( + lower=number, + upper=second_number, + lower_closed=operand.is_closed(), + upper_closed=second_operand.is_closed() + ) + else: + return cls( + lower=second_number, + upper=number, + lower_closed=second_operand.is_closed(), + upper_closed=operand.is_closed() + ) def __str__(self) -> str: if not self.lower and not self.upper: return "{empty interval}" @@ -197,8 +211,11 @@ def __str__(self) -> str: if self.lower == self.upper and self.lower_closed and self.upper_closed: return str(self.lower) - return str(self.lower) + " " + Operand.LT.with_closed(self.lower_closed).value + " " + "x" \ - + Operand.LT.with_closed(self.upper_closed).value + str(self.upper) + return str(self.lower) + " " + Operand.LT.with_closed(self.lower_closed).value + " x " \ + + Operand.LT.with_closed(self.upper_closed).value + " " + str(self.upper) + + def __repr__(self) -> str: + return f"Interval[{str(self)}]" # For parsing Intervals automatically with pydantic. 
@model_validator(mode="before") diff --git a/test/test_interval.py b/test/test_interval.py index e263817bc..1c3b69a61 100644 --- a/test/test_interval.py +++ b/test/test_interval.py @@ -4,3 +4,14 @@ class TestInterval: def test_number(self): assert Interval.single(5) == Interval(lower=5, upper=5, lower_closed=True, upper_closed=True) assert Interval.from_string("5") == Interval.single(5) + + def test_string_permutations(self): + assert Interval.from_string("<5") == Interval.from_string("< 5") + assert Interval.from_string("5<") == Interval.from_string("5 < x") + assert Interval.from_string("6<") == Interval.from_string("x > 6") + assert Interval.from_string("100 <") == Interval.from_string(">100") + assert Interval.from_string("200 >= x > 100") == Interval.from_string("100 < x <= 200") + + def test_orientation(self): + assert Interval.from_string("10<").upper is None + assert Interval.from_string("10<").lower == 10.0 From 0b6e01f59bfc1e8a25f525194d6548c82a0b11b9 Mon Sep 17 00:00:00 2001 From: "Tristan F.-R." Date: Thu, 6 Nov 2025 00:02:39 +0000 Subject: [PATCH 11/25] feat: integrate heuristics --- Snakefile | 1 + spras/config/heuristics.py | 15 +++++++++++++++ spras/interval.py | 2 +- spras/util.py | 2 ++ test/test_interval.py | 5 +++-- 5 files changed, 22 insertions(+), 3 deletions(-) diff --git a/Snakefile b/Snakefile index 02f019e8d..ffb5c9f8c 100644 --- a/Snakefile +++ b/Snakefile @@ -295,6 +295,7 @@ rule parse_output: params = reconstruction_params(wildcards.algorithm, wildcards.params).copy() params['dataset'] = input.dataset_file runner.parse_output(wildcards.algorithm, input.raw_file, output.standardized_file, params) + _config.config.heuristics.validate_graph_from_file(output.standardized_file) # TODO: reuse in the future once we make summary work for mixed graphs. 
See https://github.com/Reed-CompBio/spras/issues/128 # Collect summary statistics for a single pathway diff --git a/spras/config/heuristics.py b/spras/config/heuristics.py index 99ecb4695..2c003fe64 100644 --- a/spras/config/heuristics.py +++ b/spras/config/heuristics.py @@ -1,3 +1,5 @@ +import os + import networkx as nx from pydantic import BaseModel, ConfigDict @@ -69,3 +71,16 @@ def validate_graph(self, graph: nx.DiGraph): raise GraphHeuristicError(failed_heuristics) model_config = ConfigDict(extra='forbid') + + def validate_graph_from_file(self, path: str | os.PathLike): + # TODO: re-use from summary.py once we have a mixed/hypergraph library + G = nx.read_edgelist(path, data=(('weight', float), ('Direction', str)), create_using=nx.DiGraph) + + # We explicitly use `list` here to stop add_edge + # from expanding our iterator infinitely. + for source, target, data in list(G.edges(data=True)): + if data["Direction"] == 'U': + G.add_edge(target, source, data) + pass + + return self.validate_graph(G) diff --git a/spras/interval.py b/spras/interval.py index 7e5f6c663..266997fd5 100644 --- a/spras/interval.py +++ b/spras/interval.py @@ -57,7 +57,7 @@ def compare(self, left, right) -> bool: case Operand.EQ: return left == right case Operand.GTE: return left >= right case Operand.GT: return left > right - + def flip(self): match self: case Operand.LT: return Operand.GT diff --git a/spras/util.py b/spras/util.py index ce2cc2f96..594a2a3ad 100644 --- a/spras/util.py +++ b/spras/util.py @@ -102,6 +102,8 @@ def raw_pathway_df(raw_pathway_file: str, sep: str = '\t', header: int = None) - return df +def output_pathw + def duplicate_edges(df: pd.DataFrame) -> tuple[pd.DataFrame, bool]: """ diff --git a/test/test_interval.py b/test/test_interval.py index 1c3b69a61..840f8d057 100644 --- a/test/test_interval.py +++ b/test/test_interval.py @@ -1,17 +1,18 @@ from spras.interval import Interval + class TestInterval: def test_number(self): assert Interval.single(5) == 
Interval(lower=5, upper=5, lower_closed=True, upper_closed=True) assert Interval.from_string("5") == Interval.single(5) - + def test_string_permutations(self): assert Interval.from_string("<5") == Interval.from_string("< 5") assert Interval.from_string("5<") == Interval.from_string("5 < x") assert Interval.from_string("6<") == Interval.from_string("x > 6") assert Interval.from_string("100 <") == Interval.from_string(">100") assert Interval.from_string("200 >= x > 100") == Interval.from_string("100 < x <= 200") - + def test_orientation(self): assert Interval.from_string("10<").upper is None assert Interval.from_string("10<").lower == 10.0 From 33e004f92c2ca09ee7956633f3fb210a5dcd83d8 Mon Sep 17 00:00:00 2001 From: "Tristan F.-R." Date: Thu, 6 Nov 2025 00:56:20 +0000 Subject: [PATCH 12/25] fix: drop random code --- spras/util.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/spras/util.py b/spras/util.py index 594a2a3ad..ce2cc2f96 100644 --- a/spras/util.py +++ b/spras/util.py @@ -102,8 +102,6 @@ def raw_pathway_df(raw_pathway_file: str, sep: str = '\t', header: int = None) - return df -def output_pathw - def duplicate_edges(df: pd.DataFrame) -> tuple[pd.DataFrame, bool]: """ From c675eced3b62b8a62204d9f6105628e1cdc09045 Mon Sep 17 00:00:00 2001 From: "Tristan F." 
Date: Thu, 6 Nov 2025 02:22:45 +0000 Subject: [PATCH 13/25] fix: make undirected for determining number of connected components --- spras/statistics.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/spras/statistics.py b/spras/statistics.py index 1ebe7cc62..222051d23 100644 --- a/spras/statistics.py +++ b/spras/statistics.py @@ -24,7 +24,8 @@ def compute_degree(graph: nx.DiGraph) -> tuple[int, float]: degrees = [deg for _, deg in graph.degree()] return max(degrees), median(degrees) -def compute_on_cc(graph: nx.DiGraph) -> tuple[int, float]: +def compute_on_cc(directed_graph: nx.DiGraph) -> tuple[int, float]: + graph: nx.Graph = directed_graph.to_undirected() cc = list(nx.connected_components(graph)) # Save the max diameter # Use diameter only for components with ≥2 nodes (singleton components have diameter 0) @@ -52,7 +53,7 @@ def compute_on_cc(graph: nx.DiGraph) -> tuple[int, float]: statistics_computation: dict[tuple[str, ...], Callable[[nx.DiGraph], tuple[float | int, ...]]] = { ('Number of nodes',): lambda graph : (graph.number_of_nodes(),), ('Number of edges',): lambda graph : (graph.number_of_edges(),), - ('Number of connected components',): lambda graph : (nx.number_connected_components(graph),), + ('Number of connected components',): lambda graph : (nx.number_connected_components(graph.to_undirected()),), ('Density',): lambda graph : (nx.density(graph),), ('Max degree', 'Median degree'): compute_degree, From 1cdaf121c84e08bb8a2c9c607e3286ffdbfdda0a Mon Sep 17 00:00:00 2001 From: "Tristan F." 
Date: Wed, 5 Nov 2025 18:23:13 -0800 Subject: [PATCH 14/25] fix: specify heuristics in wrapping config object --- Snakefile | 2 ++ spras/config/config.py | 2 ++ spras/config/heuristics.py | 4 ++-- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/Snakefile b/Snakefile index ffb5c9f8c..65e24f15d 100644 --- a/Snakefile +++ b/Snakefile @@ -295,6 +295,8 @@ rule parse_output: params = reconstruction_params(wildcards.algorithm, wildcards.params).copy() params['dataset'] = input.dataset_file runner.parse_output(wildcards.algorithm, input.raw_file, output.standardized_file, params) + # TODO: cache heuristics result, store partial heuristics configuration file + # to allow this rule to update when heuristics change _config.config.heuristics.validate_graph_from_file(output.standardized_file) # TODO: reuse in the future once we make summary work for mixed graphs. See https://github.com/Reed-CompBio/spras/issues/128 diff --git a/spras/config/config.py b/spras/config/config.py index 25e6f72de..51d3daf3b 100644 --- a/spras/config/config.py +++ b/spras/config/config.py @@ -78,6 +78,8 @@ def __init__(self, raw_config: dict[str, Any]): self.container_settings = ProcessedContainerSettings.from_container_settings(parsed_raw_config.containers, self.hash_length) # The list of algorithms to run in the workflow. Each is a dict with 'name' as an expected key. self.algorithms = None + # The heuristic handler + self.heuristics = parsed_raw_config.heuristics # A nested dict mapping algorithm names to dicts that map parameter hashes to parameter combinations. # Only includes algorithms that are set to be run with 'include: true'. 
self.algorithm_params = None diff --git a/spras/config/heuristics.py b/spras/config/heuristics.py index 2c003fe64..f5f9ea24c 100644 --- a/spras/config/heuristics.py +++ b/spras/config/heuristics.py @@ -74,13 +74,13 @@ def validate_graph(self, graph: nx.DiGraph): def validate_graph_from_file(self, path: str | os.PathLike): # TODO: re-use from summary.py once we have a mixed/hypergraph library - G = nx.read_edgelist(path, data=(('weight', float), ('Direction', str)), create_using=nx.DiGraph) + G: nx.DiGraph = nx.read_edgelist(path, data=(('Rank', str), ('Direction', str)), create_using=nx.DiGraph) # We explicitly use `list` here to stop add_edge # from expanding our iterator infinitely. for source, target, data in list(G.edges(data=True)): if data["Direction"] == 'U': - G.add_edge(target, source, data) + G.add_edge(target, source, **data) pass return self.validate_graph(G) From 7b290dcb636309534ec3ccdeae06d34a1c077f1b Mon Sep 17 00:00:00 2001 From: "Tristan F." Date: Thu, 6 Nov 2025 08:31:13 +0000 Subject: [PATCH 15/25] feat: interval and heuristic testing --- spras/config/heuristics.py | 33 +++++++++++++++++++------ spras/interval.py | 30 +++++++++++++--------- test/heuristics/__init__.py | 0 test/heuristics/fixtures/empty.txt | 0 test/heuristics/fixtures/nonempty.txt | 1 + test/heuristics/fixtures/undirected.txt | 1 + test/heuristics/test_heuristics.py | 26 +++++++++++++++++++ test/test_interval.py | 3 +++ 8 files changed, 76 insertions(+), 18 deletions(-) create mode 100644 test/heuristics/__init__.py create mode 100644 test/heuristics/fixtures/empty.txt create mode 100644 test/heuristics/fixtures/nonempty.txt create mode 100644 test/heuristics/fixtures/undirected.txt create mode 100644 test/heuristics/test_heuristics.py diff --git a/spras/config/heuristics.py b/spras/config/heuristics.py index f5f9ea24c..52c4839c6 100644 --- a/spras/config/heuristics.py +++ b/spras/config/heuristics.py @@ -6,25 +6,40 @@ from spras.interval import Interval from spras.statistics
import compute_statistics, statistics_options +__all__ = ['GraphHeuristicsError', 'GraphHeuristics'] -class GraphHeuristicError(RuntimeError): +class GraphHeuristicsError(RuntimeError): """ Represents an error arising from a graph algorithm output not meeting the necessary graph heuristisc. """ failed_heuristics: list[tuple[str, float | int, list[Interval]]] + @staticmethod + def format_failed_heuristic(heuristic: tuple[str, float | int, list[Interval]]) -> str: + name, desired, intervals = heuristic + if len(intervals) == 1: + interval_string = str(intervals[0]) + else: + formatted_intervals = ", ".join([str(interval) for interval in intervals]) + interval_string = f"one of the intervals ({formatted_intervals})" + return f"{name} expected {desired} in interval {interval_string}" @staticmethod def to_string(failed_heuristics: list[tuple[str, float | int, list[Interval]]]): - return f"The following heuristics failed: {failed_heuristics}" + formatted_heuristics = [ + GraphHeuristicsError.format_failed_heuristic(heuristic) for heuristic in failed_heuristics + ] + + formatted_heuristics = "\n".join([f"- {heuristic}" for heuristic in formatted_heuristics]) + return f"The following heuristics failed:\n{formatted_heuristics}" def __init__(self, failed_heuristics: list[tuple[str, float | int, list[Interval]]]): - super().__init__(GraphHeuristicError.to_string(failed_heuristics)) + super().__init__(GraphHeuristicsError.to_string(failed_heuristics)) self.failed_heuristics = failed_heuristics def __str__(self) -> str: - return GraphHeuristicError.to_string(self.failed_heuristics) + return GraphHeuristicsError.to_string(self.failed_heuristics) class GraphHeuristics(BaseModel): number_of_nodes: Interval | list[Interval] = [] @@ -54,7 +69,7 @@ def validate_graph(self, graph: nx.DiGraph): stats = compute_statistics( graph, - list(k for k, v in statistics_dictionary.items() if isinstance(v, list) and len(v) == 0) + list(k for k, v in statistics_dictionary.items() if not
isinstance(v, list) or len(v) != 0) ) failed_heuristics: list[tuple[str, float | int, list[Interval]]] = [] @@ -63,16 +78,20 @@ def validate_graph(self, graph: nx.DiGraph): if not isinstance(intervals, list): intervals = [intervals] for interval in intervals: - if interval.mem(value): + if not interval.mem(value): failed_heuristics.append((key, value, intervals)) break if len(failed_heuristics) != 0: - raise GraphHeuristicError(failed_heuristics) + raise GraphHeuristicsError(failed_heuristics) model_config = ConfigDict(extra='forbid') def validate_graph_from_file(self, path: str | os.PathLike): + """ + Takes in a graph produced by PRM#parse_output, + and throws a GraphHeuristicsError if it fails the heuristics in `self`. + """ # TODO: re-use from summary.py once we have a mixed/hypergraph library G: nx.DiGraph = nx.read_edgelist(path, data=(('Rank', str), ('Direction', str)), create_using=nx.DiGraph) diff --git a/spras/interval.py b/spras/interval.py index 266997fd5..b65f87a7c 100644 --- a/spras/interval.py +++ b/spras/interval.py @@ -10,9 +10,10 @@ import tokenize from enum import Enum from io import BytesIO -from typing import Any, Optional, Self, cast +from typing import Any, ClassVar, Optional, Self, cast -from pydantic import BaseModel, model_validator +from pydantic import model_serializer, model_validator +from pydantic.dataclasses import dataclass class Operand(Enum): @@ -82,7 +83,10 @@ def combine(cls, left: Self, right: Self): case (Operand.GT, Operand.GT): return Operand.GT return None -class Interval(BaseModel): +@dataclass +class Interval: + EMPTY_STRING: ClassVar[str] = "{empty interval}" + lower: Optional[float] upper: Optional[float] lower_closed: bool @@ -110,10 +114,10 @@ def left_operand(cls, operand: Operand, num: float) -> Self: """Creates an interval whose operand is on the left (e.g. 
<300)""" match operand: case Operand.LT: return cls(lower=None, upper=num, lower_closed=False, upper_closed=False) - case Operand.LTE: return cls(lower=None, upper=num, lower_closed=True, upper_closed=False) + case Operand.LTE: return cls(lower=None, upper=num, lower_closed=False, upper_closed=True) case Operand.EQ: return cls.single(num) - case Operand.GTE: return cls(lower=num, upper=None, lower_closed=False, upper_closed=False) - case Operand.GT: return cls(lower=num, upper=None, lower_closed=False, upper_closed=True) + case Operand.GTE: return cls(lower=num, upper=None, lower_closed=True, upper_closed=False) + case Operand.GT: return cls(lower=num, upper=None, lower_closed=False, upper_closed=False) @classmethod def right_operand(cls, num: float, operand: Operand) -> Self: @@ -203,10 +207,10 @@ def is_id(idstr: str) -> bool: return idstr.isidentifier() ) def __str__(self) -> str: - if not self.lower and not self.upper: return "{empty interval}" - if not self.lower: + if self.lower is None and self.upper is None: return Interval.EMPTY_STRING + if self.lower is None: return Operand.LT.with_closed(self.upper_closed).value + " " + str(self.upper) - if not self.upper: + if self.upper is None: return str(self.lower) + " " + Operand.LT.with_closed(self.lower_closed).value if self.lower == self.upper and self.lower_closed and self.upper_closed: return str(self.lower) @@ -221,6 +225,10 @@ def __repr__(self) -> str: @model_validator(mode="before") @classmethod def from_literal(cls, data: Any) -> Any: - if isinstance(data, str): - return cls.from_string(data) + if isinstance(data, int) or isinstance(data, float) or isinstance(data, str): + return vars(cls.from_string(str(data))) return data + + @model_serializer(mode='plain') + def serialize_model(self) -> str: + return str(self) diff --git a/test/heuristics/__init__.py b/test/heuristics/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/heuristics/fixtures/empty.txt 
b/test/heuristics/fixtures/empty.txt new file mode 100644 index 000000000..e69de29bb diff --git a/test/heuristics/fixtures/nonempty.txt b/test/heuristics/fixtures/nonempty.txt new file mode 100644 index 000000000..8e9f8ac96 --- /dev/null +++ b/test/heuristics/fixtures/nonempty.txt @@ -0,0 +1 @@ +A B 1 D diff --git a/test/heuristics/fixtures/undirected.txt b/test/heuristics/fixtures/undirected.txt new file mode 100644 index 000000000..627d30073 --- /dev/null +++ b/test/heuristics/fixtures/undirected.txt @@ -0,0 +1 @@ +A B 1 U diff --git a/test/heuristics/test_heuristics.py b/test/heuristics/test_heuristics.py new file mode 100644 index 000000000..3512915e0 --- /dev/null +++ b/test/heuristics/test_heuristics.py @@ -0,0 +1,26 @@ +from pathlib import Path +import pytest + +from spras.config.heuristics import GraphHeuristics, GraphHeuristicsError + +FIXTURES_DIR = Path('test', 'heuristics', 'fixtures') + +class TestHeuristics: + def parse(self, heuristics: dict) -> GraphHeuristics: + return GraphHeuristics.model_validate(heuristics) + + def test_nonempty(self): + self.parse({ 'number_of_nodes': '>0', 'number_of_edges': '1' } + ).validate_graph_from_file(FIXTURES_DIR / 'nonempty.txt') + + def test_empty(self): + self.parse({ 'number_of_nodes': '<1' } + ).validate_graph_from_file(FIXTURES_DIR / 'empty.txt') + + with pytest.raises(GraphHeuristicsError): + self.parse({ 'number_of_nodes': '0<' } + ).validate_graph_from_file(FIXTURES_DIR / 'empty.txt') + + def test_undirected(self): + self.parse({ 'number_of_nodes': '1 < x < 3', 'number_of_edges': 2 } + ).validate_graph_from_file(FIXTURES_DIR / 'undirected.txt') \ No newline at end of file diff --git a/test/test_interval.py b/test/test_interval.py index 840f8d057..1481d1a79 100644 --- a/test/test_interval.py +++ b/test/test_interval.py @@ -6,6 +6,9 @@ def test_number(self): assert Interval.single(5) == Interval(lower=5, upper=5, lower_closed=True, upper_closed=True) assert Interval.from_string("5") == Interval.single(5) + def 
test_interval_gt_0(self): + assert Interval.from_string(">0") == Interval(lower=0, upper=None, lower_closed=False, upper_closed=False) + def test_string_permutations(self): assert Interval.from_string("<5") == Interval.from_string("< 5") assert Interval.from_string("5<") == Interval.from_string("5 < x") From 4844fd6917bb1fe7ee7006d82bfa3e73c6830fbb Mon Sep 17 00:00:00 2001 From: "Tristan F." Date: Thu, 6 Nov 2025 08:33:32 +0000 Subject: [PATCH 16/25] style: fmt --- test/heuristics/test_heuristics.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/heuristics/test_heuristics.py b/test/heuristics/test_heuristics.py index 3512915e0..8011f5377 100644 --- a/test/heuristics/test_heuristics.py +++ b/test/heuristics/test_heuristics.py @@ -1,4 +1,5 @@ from pathlib import Path + import pytest from spras.config.heuristics import GraphHeuristics, GraphHeuristicsError @@ -16,11 +17,11 @@ def test_nonempty(self): def test_empty(self): self.parse({ 'number_of_nodes': '<1' } ).validate_graph_from_file(FIXTURES_DIR / 'empty.txt') - + with pytest.raises(GraphHeuristicsError): self.parse({ 'number_of_nodes': '0<' } ).validate_graph_from_file(FIXTURES_DIR / 'empty.txt') def test_undirected(self): self.parse({ 'number_of_nodes': '1 < x < 3', 'number_of_edges': 2 } - ).validate_graph_from_file(FIXTURES_DIR / 'undirected.txt') \ No newline at end of file + ).validate_graph_from_file(FIXTURES_DIR / 'undirected.txt') From 1ca730e4cd36e0542fbd90496d972997db340d19 Mon Sep 17 00:00:00 2001 From: "Tristan F." 
Date: Tue, 13 Jan 2026 12:13:21 -0800 Subject: [PATCH 17/25] feat: snakemake-based summary generation --- Snakefile | 24 ++++++++++++++++++++---- spras/analysis/summary.py | 15 +++++++++------ spras/statistics.py | 28 ++-------------------------- 3 files changed, 31 insertions(+), 36 deletions(-) diff --git a/Snakefile b/Snakefile index cf075b0fa..060696c71 100644 --- a/Snakefile +++ b/Snakefile @@ -2,10 +2,11 @@ import os from spras import runner import shutil import yaml -from spras.dataset import Dataset -from spras.evaluation import Evaluation from spras.analysis import ml, summary, cytoscape import spras.config.config as _config +from spras.dataset import Dataset +from spras.evaluation import Evaluation +from spras.statistics import from_edgelist, statistics_computation, statistics_options # Snakemake updated the behavior in the 6.5.0 release https://github.com/snakemake/snakemake/pull/1037 # and using the wrong separator prevents Snakemake from matching filenames to the rules that can produce them @@ -310,18 +311,33 @@ rule viz_cytoscape: run: cytoscape.run_cytoscape(input.pathways, output.session, container_settings) +for keys, values in statistics_computation.items(): + pythonic_name = 'generate_' + '_and_'.join([key.lower().replace(' ', '_') for key in keys]) + rule: + name: pythonic_name + input: pathway_file = rules.reconstruct.output.pathway_file + output: [SEP.join([out_dir, '{dataset}-{algorithm}-{params}', 'statistics', f'{key}.txt']) for key in keys] + run: + (Path(input.pathway_file).parent / 'statistics').mkdir(exist_ok=True) + graph = from_edgelist(input.pathway_file) + for computed, output in zip(values(graph), output): + Path(output).write_text(str(computed)) # Write a single summary table for all pathways for each dataset rule summary_table: input: # Collect all pathways generated for the dataset pathways = expand('{out_dir}{sep}{{dataset}}-{algorithm_params}{sep}pathway.txt', out_dir=out_dir, sep=SEP, algorithm_params=algorithms_with_params), 
- dataset_file = SEP.join([out_dir, 'dataset-{dataset}-merged.pickle']) + dataset_file = SEP.join([out_dir, 'dataset-{dataset}-merged.pickle']), + # Collect all possible options + statistics = expand( + '{out_dir}{sep}{{dataset}}-{algorithm_params}{sep}statistics{sep}{statistic}.txt', + out_dir=out_dir, sep=SEP, algorithm_params=algorithms_with_params, statistic=statistics_options) output: summary_table = SEP.join([out_dir, '{dataset}-pathway-summary.txt']) run: # Load the node table from the pickled dataset file node_table = Dataset.from_file(input.dataset_file).node_table - summary_df = summary.summarize_networks(input.pathways, node_table, algorithm_params, algorithms_with_params) + summary_df = summary.summarize_networks(input.pathways, node_table, algorithm_params, algorithms_with_params, input.statistics) summary_df.to_csv(output.summary_table, sep='\t', index=False) # Cluster the output pathways for each dataset diff --git a/spras/analysis/summary.py b/spras/analysis/summary.py index cdffe0f68..0bd025aa4 100644 --- a/spras/analysis/summary.py +++ b/spras/analysis/summary.py @@ -1,14 +1,14 @@ +import ast from pathlib import Path from typing import Iterable -import networkx as nx import pandas as pd -from spras.statistics import compute_statistics, statistics_options +from spras.statistics import from_edgelist def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, algo_params: dict[str, dict], - algo_with_params: list) -> pd.DataFrame: + algo_with_params: list, statistics_files: list) -> pd.DataFrame: """ Generate a table that aggregates summary information about networks in file_paths, including which nodes are present in node_table columns. Network directionality is ignored and all edges are treated as undirected. 
The order of the @@ -44,15 +44,16 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg lines = f.readlines()[1:] # skip the header line # directed or mixed graphs are parsed and summarized as an undirected graph - nw = nx.read_edgelist(lines, data=(('weight', float), ('Direction', str))) + nw = from_edgelist(lines) # Save the network name, number of nodes, number edges, and number of connected components nw_name = str(file_path) - graph_statistics = compute_statistics(nw, statistics_options) + # We use literal_eval here to easily coerce to either ints or floats, depending. + graph_statistics = [ast.literal_eval(Path(file).read_text()) for file in statistics_files] # Initialize list to store current network information - cur_nw_info = [nw_name, *graph_statistics.values()] + cur_nw_info = [nw_name, *graph_statistics] # Iterate through each node property and save the intersection with the current network for node_list in nodes_by_col: @@ -73,6 +74,8 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg # Save the current network information to the network summary list nw_info.append(cur_nw_info) + # Get the list of statistic names by their file names + statistics_options = [Path(file).stem for file in statistics_files] # Prepare column names col_names = ['Name', *statistics_options] col_names.extend(nodes_by_col_labs) diff --git a/spras/statistics.py b/spras/statistics.py index 222051d23..7bc8253c6 100644 --- a/spras/statistics.py +++ b/spras/statistics.py @@ -63,29 +63,5 @@ def compute_on_cc(directed_graph: nx.DiGraph) -> tuple[int, float]: # All of the keys inside statistics_computation, flattened. 
statistics_options: list[str] = list(itertools.chain(*(list(key) for key in statistics_computation.keys()))) -def compute_statistics(graph: nx.DiGraph, statistics: list[str]) -> dict[str, float | int]: - """ - Computes `statistics` for a graph corresponding to the top-level `statistics` dictionary - in this file. - """ - - # early-scan cutoff for statistics: - # we want to err as soon as possible - for stat in statistics: - if stat not in statistics_options: - raise RuntimeError(f"Statistic {stat} not a computable statistics! Available statistics: {statistics_options}") - - # now, we can compute statistics only - computed_statistics: dict[str, float | int] = dict() - for statistic_tuple, compute in statistics_computation.items(): - # when we want them - if not set(statistic_tuple).isdisjoint(set(statistics)): - computed_tuple = compute(graph) - assert len(statistic_tuple) == len(computed_tuple), f"bad tuple length for {statistic_tuple}" - - current_computed_statistics = zip(statistic_tuple, computed_tuple, strict=True) - for stat, value in current_computed_statistics: - computed_statistics[stat] = value - - # (and return only the statistics we wanted) - return {key: computed_statistics[key] for key in statistics} +def from_edgelist(lines) -> nx.Graph: + return nx.read_edgelist(lines, data=(('weight', float), ('Direction', str))) From d67186dcd5679c44264b24836d86f25816aecb52 Mon Sep 17 00:00:00 2001 From: "Tristan F." 
Date: Tue, 13 Jan 2026 12:19:43 -0800 Subject: [PATCH 18/25] fix(Snakefile): use parse_output for edgelist parsing --- Snakefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Snakefile b/Snakefile index 060696c71..532be6fe7 100644 --- a/Snakefile +++ b/Snakefile @@ -315,7 +315,7 @@ for keys, values in statistics_computation.items(): pythonic_name = 'generate_' + '_and_'.join([key.lower().replace(' ', '_') for key in keys]) rule: name: pythonic_name - input: pathway_file = rules.reconstruct.output.pathway_file + input: pathway_file = rules.parse_output.output.standardized_file output: [SEP.join([out_dir, '{dataset}-{algorithm}-{params}', 'statistics', f'{key}.txt']) for key in keys] run: (Path(input.pathway_file).parent / 'statistics').mkdir(exist_ok=True) From fd483c3af9ab15bb5b1717b6a33b1ae338b25472 Mon Sep 17 00:00:00 2001 From: "Tristan F." Date: Tue, 13 Jan 2026 12:37:46 -0800 Subject: [PATCH 19/25] fix: parse edgelist with rank, embed header skip inside from_edgelist this had incorrect behavior ? 
--- Snakefile | 4 ++-- spras/analysis/summary.py | 7 ++----- spras/statistics.py | 7 +++++-- test/analysis/test_summary.py | 24 ++++++++++++------------ 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/Snakefile b/Snakefile index 532be6fe7..9673b80b0 100644 --- a/Snakefile +++ b/Snakefile @@ -6,7 +6,7 @@ from spras.analysis import ml, summary, cytoscape import spras.config.config as _config from spras.dataset import Dataset from spras.evaluation import Evaluation -from spras.statistics import from_edgelist, statistics_computation, statistics_options +from spras.statistics import from_output_pathway, statistics_computation, statistics_options # Snakemake updated the behavior in the 6.5.0 release https://github.com/snakemake/snakemake/pull/1037 # and using the wrong separator prevents Snakemake from matching filenames to the rules that can produce them @@ -319,7 +319,7 @@ for keys, values in statistics_computation.items(): output: [SEP.join([out_dir, '{dataset}-{algorithm}-{params}', 'statistics', f'{key}.txt']) for key in keys] run: (Path(input.pathway_file).parent / 'statistics').mkdir(exist_ok=True) - graph = from_edgelist(input.pathway_file) + graph = from_output_pathway(input.pathway_file) for computed, output in zip(values(graph), output): Path(output).write_text(str(computed)) diff --git a/spras/analysis/summary.py b/spras/analysis/summary.py index 0bd025aa4..1f627493f 100644 --- a/spras/analysis/summary.py +++ b/spras/analysis/summary.py @@ -4,7 +4,7 @@ import pandas as pd -from spras.statistics import from_edgelist +from spras.statistics import from_output_pathway def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, algo_params: dict[str, dict], @@ -40,11 +40,8 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg # Iterate through each network file path for index, file_path in enumerate(sorted(file_paths)): - with open(file_path, 'r') as f: - lines = f.readlines()[1:] # skip the header 
line - # directed or mixed graphs are parsed and summarized as an undirected graph - nw = from_edgelist(lines) + nw = from_output_pathway(file_path) # Save the network name, number of nodes, number edges, and number of connected components nw_name = str(file_path) diff --git a/spras/statistics.py b/spras/statistics.py index 7bc8253c6..5399da390 100644 --- a/spras/statistics.py +++ b/spras/statistics.py @@ -63,5 +63,8 @@ def compute_on_cc(directed_graph: nx.DiGraph) -> tuple[int, float]: # All of the keys inside statistics_computation, flattened. statistics_options: list[str] = list(itertools.chain(*(list(key) for key in statistics_computation.keys()))) -def from_edgelist(lines) -> nx.Graph: - return nx.read_edgelist(lines, data=(('weight', float), ('Direction', str))) +def from_output_pathway(lines) -> nx.Graph: + with open(lines, 'r') as f: + lines = f.readlines()[1:] + + return nx.read_edgelist(lines, data=(('Rank', int), ('Direction', str))) diff --git a/test/analysis/test_summary.py b/test/analysis/test_summary.py index 57f1f6012..8618f0a2f 100644 --- a/test/analysis/test_summary.py +++ b/test/analysis/test_summary.py @@ -12,9 +12,9 @@ # - 'NODEID' is required as the first column label in the node table # - file_paths must be an iterable, even if a single file path is passed -INPUT_DIR = 'test/analysis/input/' -OUT_DIR = 'test/analysis/output/' -EXPECT_DIR = 'test/analysis/expected_output/' +INPUT_DIR = Path('test', 'analysis', 'input') +OUT_DIR = Path('test', 'analysis', 'output') +EXPECT_DIR = Path('test', 'analysis', 'expected_output') class TestSummary: @@ -35,14 +35,14 @@ def test_example_networks(self): } example_dataset = Dataset(example_dict) example_node_table = example_dataset.node_table - config.init_from_file(INPUT_DIR + "config.yaml") + config.init_from_file(INPUT_DIR / "config.yaml") algorithm_params = config.config.algorithm_params algorithms_with_params = [f'{algorithm}-params-{params_hash}' for algorithm, param_combos in 
algorithm_params.items() for params_hash in param_combos.keys()] - example_network_files = Path(INPUT_DIR + "example").glob("*.txt") # must be path to use .glob() + example_network_files = Path(INPUT_DIR, "example").glob("*.txt") - out_path = Path(OUT_DIR + "test_example_summary.txt") + out_path = Path(OUT_DIR, "test_example_summary.txt") out_path.unlink(missing_ok=True) summarize_example = summarize_networks(example_network_files, example_node_table, algorithm_params, algorithms_with_params) @@ -51,7 +51,7 @@ def test_example_networks(self): # Comparing the dataframes directly with equals does not match because of how the parameter # combinations column is loaded from disk. Therefore, write both to disk and compare the files. - assert filecmp.cmp(out_path, EXPECT_DIR + "expected_example_summary.txt", shallow=False) + assert filecmp.cmp(out_path, EXPECT_DIR / "expected_example_summary.txt", shallow=False) def test_egfr_networks(self): """Test data from EGFR workflow""" @@ -64,14 +64,14 @@ def test_egfr_networks(self): egfr_dataset = Dataset(egfr_dict) egfr_node_table = egfr_dataset.node_table - config.init_from_file(INPUT_DIR + "egfr.yaml") + config.init_from_file(INPUT_DIR / "egfr.yaml") algorithm_params = config.config.algorithm_params algorithms_with_params = [f'{algorithm}-params-{params_hash}' for algorithm, param_combos in algorithm_params.items() for params_hash in param_combos.keys()] - egfr_network_files = Path(INPUT_DIR + "egfr").glob("*.txt") # must be path to use .glob() + egfr_network_files = Path(INPUT_DIR, "egfr").glob("*.txt") # must be path to use .glob() - out_path = Path(OUT_DIR + "test_egfr_summary.txt") + out_path = Path(OUT_DIR, "test_egfr_summary.txt") out_path.unlink(missing_ok=True) summarize_egfr = summarize_networks(egfr_network_files, egfr_node_table, algorithm_params, algorithms_with_params) @@ -80,7 +80,7 @@ def test_egfr_networks(self): # Comparing the dataframes directly with equals does not match because of how the parameter # 
combinations column is loaded from disk. Therefore, write both to disk and compare the files. - assert filecmp.cmp(out_path, EXPECT_DIR + "expected_egfr_summary.txt", shallow=False) + assert filecmp.cmp(out_path, EXPECT_DIR / "expected_egfr_summary.txt", shallow=False) def test_load_dataset_dict(self): """Test loading files from dataset_dict""" @@ -95,7 +95,7 @@ def test_load_dataset_dict(self): # node_table contents are not generated consistently in the same order, # so we will check that the contents are the same, but row order doesn't matter - expected_node_table = pd.read_csv((EXPECT_DIR + "expected_node_table.txt"), sep="\t") + expected_node_table = pd.read_csv((EXPECT_DIR / "expected_node_table.txt"), sep="\t") # ignore 'NODEID' column because this changes each time upon new generation cols_to_compare = [col for col in example_node_table.columns if col != "NODEID"] From fd5046f165f3ab29e6e154f29f4eab7316a0fb45 Mon Sep 17 00:00:00 2001 From: "Tristan F." Date: Tue, 13 Jan 2026 12:55:38 -0800 Subject: [PATCH 20/25] style: fmt --- spras/statistics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spras/statistics.py b/spras/statistics.py index 5399da390..342f1a5e2 100644 --- a/spras/statistics.py +++ b/spras/statistics.py @@ -66,5 +66,5 @@ def compute_on_cc(directed_graph: nx.DiGraph) -> tuple[int, float]: def from_output_pathway(lines) -> nx.Graph: with open(lines, 'r') as f: lines = f.readlines()[1:] - + return nx.read_edgelist(lines, data=(('Rank', int), ('Direction', str))) From 79cf748b9efe78dff51e69963591ef267a3eb0c8 Mon Sep 17 00:00:00 2001 From: "Tristan F." 
Date: Tue, 13 Jan 2026 13:17:48 -0800 Subject: [PATCH 21/25] chore: mention statistics_files param --- spras/analysis/summary.py | 1 + 1 file changed, 1 insertion(+) diff --git a/spras/analysis/summary.py b/spras/analysis/summary.py index 1f627493f..e5c0b1f73 100644 --- a/spras/analysis/summary.py +++ b/spras/analysis/summary.py @@ -18,6 +18,7 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg @param algo_params: a nested dict mapping algorithm names to dicts that map parameter hashes to parameter combinations. @param algo_with_params: a list of -params- combinations + @param statistics_files: a list of statistic files with the computed statistics. @return: pandas DataFrame with summary information """ # Ensure that NODEID is the first column From 85e0ea8a020133186074b7663c83fa4c1a253a9b Mon Sep 17 00:00:00 2001 From: "Tristan F.-R." Date: Sat, 14 Feb 2026 01:09:44 +0000 Subject: [PATCH 22/25] docs: more info on summary & statistics --- Snakefile | 3 ++ spras/analysis/summary.py | 66 ++------------------------------------- spras/statistics.py | 10 ++++-- 3 files changed, 14 insertions(+), 65 deletions(-) diff --git a/Snakefile b/Snakefile index e4b829dee..e6d9204a5 100644 --- a/Snakefile +++ b/Snakefile @@ -310,9 +310,12 @@ rule viz_cytoscape: run: cytoscape.run_cytoscape(input.pathways, output.session, container_settings) +# We generate new Snakemake rules for every statistic +# to allow parallel and lazy computation of individual statistics for keys, values in statistics_computation.items(): pythonic_name = 'generate_' + '_and_'.join([key.lower().replace(' ', '_') for key in keys]) rule: + # (See https://snakemake.readthedocs.io/en/stable/snakefiles/rules.html#procedural-rule-definition) name: pythonic_name input: pathway_file = rules.parse_output.output.standardized_file output: [SEP.join([out_dir, '{dataset}-{algorithm}-{params}', 'statistics', f'{key}.txt']) for key in keys] diff --git a/spras/analysis/summary.py 
b/spras/analysis/summary.py index 2907b0e34..d2059bb21 100644 --- a/spras/analysis/summary.py +++ b/spras/analysis/summary.py @@ -47,7 +47,8 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg # Save the network name, number of nodes, number edges, and number of connected components nw_name = str(file_path) - # We use literal_eval here to easily coerce to either ints or floats, depending. + # We use ast.literal_eval here to convert statistic file outputs to ints or floats depending on their string representation. + # (e.g. "5.0" -> float(5.0), while "5" -> int(5).) graph_statistics = [ast.literal_eval(Path(file).read_text()) for file in statistics_files] # Initialize list to store current network information @@ -89,65 +90,4 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg return nw_info -def degree(g): - return dict(g.degree) - -# TODO: redo .run code to work on mixed graphs -# stats is just a list of functions to apply to the graph. -# They should take as input a networkx graph or digraph but may have any output. -# stats = [degree, nx.clustering, nx.betweenness_centrality] - - -# def produce_statistics(g: nx.Graph, s=None) -> dict: -# global stats -# if s is not None: -# stats = s -# d = dict() -# for s in stats: -# sname = s.__name__ -# d[sname] = s(g) -# return d - - -# def load_graph(path: str) -> nx.Graph: -# g = nx.read_edgelist(path, data=(('weight', float), ('Direction',str))) -# return g - - -# def save(data, pth): -# fout = open(pth, 'w') -# fout.write('#node\t%s\n' % '\t'.join([s.__name__ for s in stats])) -# for node in data[stats[0].__name__]: -# row = [data[s.__name__][node] for s in stats] -# fout.write('%s\t%s\n' % (node, '\t'.join([str(d) for d in row]))) -# fout.close() - - -# def run(infile: str, outfile: str) -> None: -# """ -# run function that wraps above functions. -# """ -# # if output directory doesn't exist, make it. 
-# outdir = os.path.dirname(outfile) -# if not os.path.exists(outdir): -# os.makedirs(outdir) - -# # load graph, produce stats, and write to human-readable file. -# g = load_graph(infile) -# dat = produce_statistics(g) -# save(dat, outfile) - - -# def main(argv): -# """ -# for testing -# """ -# g = load_graph(argv[1]) -# print(g.nodes) -# dat = produce_statistics(g) -# print(dat) -# save(dat, argv[2]) - - -# if __name__ == '__main__': -# main(sys.argv) +# TODO: redo the above code to work on mixed graphs diff --git a/spras/statistics.py b/spras/statistics.py index 342f1a5e2..9c510a151 100644 --- a/spras/statistics.py +++ b/spras/statistics.py @@ -4,6 +4,10 @@ We allow for arbitrary computation of any specific statistic on some graph, computing more than necessary if we have dependencies. See the top level `statistics_computation` dictionary for usage. + +To make the statistics allow directed graph input, they will always take +in a networkx.DiGraph, which contains even more information, even though +the underlying graph may be just as easily represented by networkx.Graph. """ import itertools @@ -25,6 +29,9 @@ def compute_degree(graph: nx.DiGraph) -> tuple[int, float]: return max(degrees), median(degrees) def compute_on_cc(directed_graph: nx.DiGraph) -> tuple[int, float]: + # We convert our directed_graph to an undirected graph as networkx (reasonably) does + # not allow for computing the connected components of a directed graph, but the connected + # component count still is a useful statistic for us. graph: nx.Graph = directed_graph.to_undirected() cc = list(nx.connected_components(graph)) # Save the max diameter @@ -49,13 +56,12 @@ def compute_on_cc(directed_graph: nx.DiGraph) -> tuple[int, float]: return max_diameter, avg_path_len -# The type signature on here is quite bad. I would like to say that an n-tuple has n-outputs. +# The type signature here is meant to be 'an n-tuple has n-outputs.' 
statistics_computation: dict[tuple[str, ...], Callable[[nx.DiGraph], tuple[float | int, ...]]] = { ('Number of nodes',): lambda graph : (graph.number_of_nodes(),), ('Number of edges',): lambda graph : (graph.number_of_edges(),), ('Number of connected components',): lambda graph : (nx.number_connected_components(graph.to_undirected()),), ('Density',): lambda graph : (nx.density(graph),), - ('Max degree', 'Median degree'): compute_degree, ('Max diameter', 'Average path length'): compute_on_cc, } From 804849a4c7800d1a62eb67b11d0ded2b996e1e1d Mon Sep 17 00:00:00 2001 From: "Tristan F.-R." Date: Sat, 14 Feb 2026 01:12:19 +0000 Subject: [PATCH 23/25] style: fmt --- spras/statistics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spras/statistics.py b/spras/statistics.py index 9c510a151..251fecca2 100644 --- a/spras/statistics.py +++ b/spras/statistics.py @@ -5,7 +5,7 @@ computing more than necessary if we have dependencies. See the top level `statistics_computation` dictionary for usage. -To make the statistics allow directed graph input, they will always take +To make the statistics allow directed graph input, they will always take in a networkx.DiGraph, which contains even more information, even though the underlying graph may be just as easily represented by networkx.Graph. """ From ae61e5752302234e3cd1dc5e9a751bfb33856213 Mon Sep 17 00:00:00 2001 From: "Tristan F.-R." 
Date: Fri, 17 Apr 2026 19:27:21 +0000 Subject: [PATCH 24/25] Merge branch 'umain' into generate-all-inputs --- Snakefile | 2 +- spras/analysis/summary.py | 6 ++---- test/analysis/input/egfr.yaml | 3 +++ test/analysis/input/example.yaml | 3 +++ test/analysis/test_summary.py | 7 ++++--- 5 files changed, 13 insertions(+), 8 deletions(-) diff --git a/Snakefile b/Snakefile index 166c8b2ff..ea935bec2 100644 --- a/Snakefile +++ b/Snakefile @@ -334,7 +334,7 @@ rule summary_table: # Collect all pathways generated for the dataset pathways = expand('{out_dir}{sep}{{dataset}}-{algorithm_params}{sep}pathway.txt', out_dir=out_dir, sep=SEP, algorithm_params=algorithms_with_params), dataset_file = SEP.join([out_dir, 'dataset-{dataset}-merged.pickle']), - # Collect all possible options + # Collect all possible statistics into a dictionary statistics = expand( '{out_dir}{sep}{{dataset}}-{algorithm_params}{sep}statistics{sep}{statistic}.txt', out_dir=out_dir, sep=SEP, algorithm_params=algorithms_with_params, statistic=statistics_options) diff --git a/spras/analysis/summary.py b/spras/analysis/summary.py index 237f7c5ba..cac952403 100644 --- a/spras/analysis/summary.py +++ b/spras/analysis/summary.py @@ -1,5 +1,6 @@ import ast import json +import os from pathlib import Path from typing import Iterable @@ -9,7 +10,7 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, algo_params: dict[str, dict], - algo_with_params: list[str], statistics_files: list) -> pd.DataFrame: + algo_with_params: list[str], statistics_files: list[str | os.PathLike]) -> pd.DataFrame: """ Generate a table that aggregates summary information about networks in file_paths, including which nodes are present in node_table columns. Network directionality is ignored and all edges are treated as undirected. 
The order of the @@ -90,6 +91,3 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg ) return nw_info - - -# TODO: redo the above code to work on mixed graphs diff --git a/test/analysis/input/egfr.yaml b/test/analysis/input/egfr.yaml index c9ed5f735..77c3bedf3 100644 --- a/test/analysis/input/egfr.yaml +++ b/test/analysis/input/egfr.yaml @@ -31,3 +31,6 @@ datasets: reconstruction_settings: locations: reconstruction_dir: "test/analysis/input/run/egfr" +analysis: + summary: + include: true diff --git a/test/analysis/input/example.yaml b/test/analysis/input/example.yaml index 1a4514c00..15f4a69b4 100644 --- a/test/analysis/input/example.yaml +++ b/test/analysis/input/example.yaml @@ -48,3 +48,6 @@ gold_standards: reconstruction_settings: locations: reconstruction_dir: "test/analysis/input/run/example" +analysis: + summary: + include: true diff --git a/test/analysis/test_summary.py b/test/analysis/test_summary.py index b548b8087..f6e940cf5 100644 --- a/test/analysis/test_summary.py +++ b/test/analysis/test_summary.py @@ -32,7 +32,7 @@ def snakemake_output(request): param = request.param subprocess.run(["snakemake", "--cores", "1", "--configfile", f"test/analysis/input/{param}.yaml"]) yield param # this runs the test itself: once this is passed, we go to test cleanup. 
- shutil.rmtree(f"test/analysis/input/run/{param}") + # shutil.rmtree(f"test/analysis/input/run/{param}") class TestSummary: @classmethod @@ -56,11 +56,12 @@ def test_example_networks(self, snakemake_output): algorithms_with_params = [f'{algorithm}-params-{params_hash}' for algorithm, param_combos in algorithm_params.items() for params_hash in param_combos.keys()] - example_network_files = (INPUT_DIR / "run" / snakemake_output).rglob("pathway.txt") + network_files = (INPUT_DIR / "run" / snakemake_output).rglob("pathway.txt") + statistics_files = (INPUT_DIR / "run" / snakemake_output).rglob("**/statistics/**") out_path = Path(OUT_DIR, f"test_{snakemake_output}_summary.txt") out_path.unlink(missing_ok=True) - summarize_out = summarize_networks(example_network_files, example_node_table, algorithm_params, + summarize_out = summarize_networks(network_files, example_node_table, algorithm_params, algorithms_with_params) # We do some post-processing to ensure that we get a stable summarize_out, since the attached hash # is subject to variation (especially in testing) whenever the SPRAS commit revision gets changed From 4fe949d89d23501c3d45edab93800124ce656177 Mon Sep 17 00:00:00 2001 From: "Tristan F." 
Date: Sat, 25 Apr 2026 08:24:16 +0000 Subject: [PATCH 25/25] refactor: use dictionaries instead of a flat list along with proper Snakemake procedural rule usage --- Snakefile | 26 +++++++++++++++++++------- spras/analysis/summary.py | 24 ++++++++++++++++-------- spras/statistics.py | 5 ++--- test/analysis/test_summary.py | 7 ++++--- 4 files changed, 41 insertions(+), 21 deletions(-) diff --git a/Snakefile b/Snakefile index ea935bec2..f294efd1d 100644 --- a/Snakefile +++ b/Snakefile @@ -315,34 +315,46 @@ rule viz_cytoscape: # We generate new Snakemake rules for every statistic # to allow parallel and lazy computation of individual statistics -for keys, values in statistics_computation.items(): +for keys in statistics_computation.keys(): pythonic_name = 'generate_' + '_and_'.join([key.lower().replace(' ', '_') for key in keys]) rule: # (See https://snakemake.readthedocs.io/en/stable/snakefiles/rules.html#procedural-rule-definition) name: pythonic_name input: pathway_file = rules.parse_output.output.standardized_file output: [SEP.join([out_dir, '{dataset}-{algorithm}-{params}', 'statistics', f'{key}.txt']) for key in keys] + # It is very tempting to use `.items()` instead of `.keys()` above, but + # we instead need to pass keys in via parameters, else the job would use the latest values in the statistics_computation. 
+ # More info is in the procedural rule link above. + params: statistics_names=keys run: (Path(input.pathway_file).parent / 'statistics').mkdir(exist_ok=True) graph = from_output_pathway(input.pathway_file) - for computed, output in zip(values(graph), output): + for computed, output in zip(statistics_computation[params.statistics_names](graph), output): Path(output).write_text(str(computed)) +# We isolate this to a separate input function, as we want to preserve the dictionary structure +def summary_files(wildcards): + return { + algorithm_param: expand( + '{out_dir}{sep}{dataset}-{algorithm_param}{sep}statistics{sep}{statistic}.txt', + out_dir=out_dir, sep=SEP, algorithm_param=algorithm_param, statistic=statistics_options, + dataset=wildcards.dataset + ) for algorithm_param in algorithms_with_params + } + # Write a single summary table for all pathways for each dataset rule summary_table: input: # Collect all pathways generated for the dataset pathways = expand('{out_dir}{sep}{{dataset}}-{algorithm_params}{sep}pathway.txt', out_dir=out_dir, sep=SEP, algorithm_params=algorithms_with_params), dataset_file = SEP.join([out_dir, 'dataset-{dataset}-merged.pickle']), - # Collect all possible statistics into a dictionary - statistics = expand( - '{out_dir}{sep}{{dataset}}-{algorithm_params}{sep}statistics{sep}{statistic}.txt', - out_dir=out_dir, sep=SEP, algorithm_params=algorithms_with_params, statistic=statistics_options) + # Collect all possible statistics from the `summary_files` dictionary-based input function + statistics = lambda wildcards: flatten(list(summary_files(wildcards).values())) output: summary_table = SEP.join([out_dir, '{dataset}-pathway-summary.txt']) run: # Load the node table from the pickled dataset file node_table = Dataset.from_file(input.dataset_file).node_table - summary_df = summary.summarize_networks(input.pathways, 
node_table, algorithm_params, algorithms_with_params, summary_files(wildcards)) summary_df.to_csv(output.summary_table, sep='\t', index=False) # Cluster the output pathways for each dataset diff --git a/spras/analysis/summary.py b/spras/analysis/summary.py index cac952403..bdec9baca 100644 --- a/spras/analysis/summary.py +++ b/spras/analysis/summary.py @@ -1,16 +1,17 @@ import ast +import itertools import json import os from pathlib import Path -from typing import Iterable +from typing import Iterable, Mapping import pandas as pd -from spras.statistics import from_output_pathway +from spras.statistics import from_output_pathway, statistics_options def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, algo_params: dict[str, dict], - algo_with_params: list[str], statistics_files: list[str | os.PathLike]) -> pd.DataFrame: + algo_with_params: list[str], statistics_files: Mapping[str, Iterable[str | os.PathLike]]) -> pd.DataFrame: """ Generate a table that aggregates summary information about networks in file_paths, including which nodes are present in node_table columns. Network directionality is ignored and all edges are treated as undirected. The order of the @@ -20,7 +21,7 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg @param algo_params: a nested dict mapping algorithm names to dicts that map parameter hashes to parameter combinations. @param algo_with_params: a list of -params- combinations - @param statistics_files: a list of statistic files with the computed statistics. + @param statistics_files: a dictionary from algo_with_params to lists of statistic files with the computed statistics. @return: pandas DataFrame with summary information """ # Ensure that NODEID is the first column @@ -51,7 +52,11 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg # We use ast.literal_eval here to convert statistic file outputs to ints or floats depending on their string representation. 
# (e.g. "5.0" -> float(5.0), while "5" -> int(5).) - graph_statistics = [ast.literal_eval(Path(file).read_text()) for file in statistics_files] + graph_statistics = [ + ast.literal_eval(Path(file).read_text()) for file in + # along with sorting to keep the output stable (this happens again) + sorted(statistics_files[algo_with_params[index]], key=lambda x: statistics_options.index(Path(x).stem)) + ] # Initialize list to store current network information cur_nw_info = [nw_name, *graph_statistics] @@ -76,10 +81,13 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg # Save the current network information to the network summary list nw_info.append(cur_nw_info) - # Get the list of statistic names by their file names - statistics_options = [Path(file).stem for file in statistics_files] + # Get the list of statistic names by their file names (via finding all requested statistics in the provided files) + current_statistics_options = sorted( + set(Path(file).stem for file in itertools.chain(*statistics_files.values())), + key=lambda x: statistics_options.index(x) + ) # Prepare column names - col_names = ['Name', *statistics_options] + col_names = ['Name', *current_statistics_options] col_names.extend(nodes_by_col_labs) col_names.append('Parameter combination') diff --git a/spras/statistics.py b/spras/statistics.py index 251fecca2..f303c8a8f 100644 --- a/spras/statistics.py +++ b/spras/statistics.py @@ -71,6 +71,5 @@ def compute_on_cc(directed_graph: nx.DiGraph) -> tuple[int, float]: def from_output_pathway(lines) -> nx.Graph: with open(lines, 'r') as f: - lines = f.readlines()[1:] - - return nx.read_edgelist(lines, data=(('Rank', int), ('Direction', str))) + next(f) # skip the header line + return nx.read_edgelist(f, data=(('Rank', int), ('Direction', str)), delimiter='\t') diff --git a/test/analysis/test_summary.py b/test/analysis/test_summary.py index f6e940cf5..303970997 100644 --- a/test/analysis/test_summary.py +++ 
b/test/analysis/test_summary.py @@ -1,5 +1,4 @@ import filecmp -import shutil import subprocess from pathlib import Path @@ -57,12 +56,14 @@ def test_example_networks(self, snakemake_output): algorithm_params.items() for params_hash in param_combos.keys()] network_files = (INPUT_DIR / "run" / snakemake_output).rglob("pathway.txt") - statistics_files = (INPUT_DIR / "run" / snakemake_output).rglob("**/statistics/**") + statistics_folders = [Path(file) for file in (INPUT_DIR / "run" / snakemake_output).rglob("**/statistics") if Path(file).name == "statistics"] + # We do some string fiddling here to make sure the folder matches up with algorithms_with_params. This may be susceptible to a good refactor. + statistics_files = {"-".join(folder.parent.stem.split("-")[1:]): list(folder.glob("*.txt")) for folder in statistics_folders} out_path = Path(OUT_DIR, f"test_{snakemake_output}_summary.txt") out_path.unlink(missing_ok=True) summarize_out = summarize_networks(network_files, example_node_table, algorithm_params, - algorithms_with_params) + algorithms_with_params, statistics_files) # We do some post-processing to ensure that we get a stable summarize_out, since the attached hash # is subject to variation (especially in testing) whenever the SPRAS commit revision gets changed summarize_out["Parameter combination"] = summarize_out["Parameter combination"].astype(str)