From fe9c7e668337b9da238b51b8f5eeeeb93724aafe Mon Sep 17 00:00:00 2001 From: anna-grim Date: Wed, 22 Apr 2026 23:57:48 +0000 Subject: [PATCH 1/3] refactor: updated subroutine names --- .../merge_proofreading/merge_datasets.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/neuron_proofreader/merge_proofreading/merge_datasets.py b/src/neuron_proofreader/merge_proofreading/merge_datasets.py index 383e5483..31914dbb 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datasets.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datasets.py @@ -144,12 +144,12 @@ def load_fragment_graphs( graph.load(swc_pointer) # Remove groundtruth skeletons - for swc_id in graph.get_swc_ids(): + for swc_id in graph.swc_ids(): if swc_id.lower().startswith("n"): component_id = util.find_key( graph.component_id_to_swc_id, swc_id ) - nodes = graph.get_nodes_with_component_id(component_id) + nodes = graph.nodes_with_component_id(component_id) graph.remove_nodes(nodes, relabel_nodes=False) # Remove fragments excluded from merge sites @@ -158,7 +158,7 @@ def load_fragment_graphs( segment_ids = set(merge_sites["segment_id"].unique()) for nodes in map(list, list(nx.connected_components(graph))): node = util.sample_once(nodes) - segment_id = graph.get_node_segment_id(node) + segment_id = graph.node_segment_id(node) if segment_id not in segment_ids: graph.remove_nodes(nodes, relabel_nodes=False) graph.relabel_nodes() @@ -271,7 +271,7 @@ def remove_nonindexed_fragments(self, idxs): # Find fragment containing site if pair not in visited: - nodes = self.graphs[brain_id].get_nodes_with_segment_id(segment_id) + nodes = self.graphs[brain_id].nodes_with_segment_id(segment_id) self.graphs[brain_id].remove_nodes(nodes, False) visited.add(pair) @@ -425,13 +425,13 @@ def get_random_negative_site(self): # Sample node if outcome < 0.4: # Any node - node = util.sample_once(list(self.graphs[brain_id].nodes)) + node = util.sample_once(self.graphs[brain_id].nodes) #elif outcome < 0.5: # # Node close to soma # node = self.sample_node_nearby_soma(brain_id) elif outcome < 0.8: # Branching node - branching_nodes = self.graphs[brain_id].get_branchings() + branching_nodes = self.graphs[brain_id].branching_nodes() if len(branching_nodes) > 0: node = util.sample_once(branching_nodes) else: @@ -439,7 +439,7 @@ def get_random_negative_site(self): continue else: # Branching node from GT - branching_nodes = self.gt_graphs[brain_id].get_branchings() + branching_nodes = self.gt_graphs[brain_id].branching_nodes() node = util.sample_once(branching_nodes) subgraph = self.gt_graphs[brain_id].rooted_subgraph( node, self.subgraph_radius @@ -828,7 +828,7 @@ def add_examples(): for brain_id, graph in self.graphs.items(): # Filter branching nodes near other branching nodes nodes = list() - for i in graph.get_branchings(): + for i in graph.branching_nodes(): is_branchy = self.check_nearby_branching(brain_id, i) if not is_branchy and graph.degree[i] == 3: nodes.append(i) From 7d8007a3641f039a31a8430e3ebe5afe10eafb61 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Thu, 23 Apr 2026 00:27:31 +0000 Subject: [PATCH 2/3] bug: updated swc_id->component_id --- src/neuron_proofreader/merge_proofreading/merge_datasets.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/neuron_proofreader/merge_proofreading/merge_datasets.py b/src/neuron_proofreader/merge_proofreading/merge_datasets.py index 31914dbb..347a687e 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datasets.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datasets.py @@ -146,9 +146,7 @@ def load_fragment_graphs( # Remove groundtruth skeletons for swc_id in graph.swc_ids(): if swc_id.lower().startswith("n"): - component_id = util.find_key( - graph.component_id_to_swc_id, swc_id - ) + component_id = graph.component_id_from_swc_id(swc_id) nodes = graph.nodes_with_component_id(component_id) graph.remove_nodes(nodes, relabel_nodes=False) From 3e3059ffb8888b1bd85800ba469a77835431b87b Mon Sep 17 00:00:00 2001 From: anna-grim Date: Thu, 23 Apr 2026 02:44:07 +0000 Subject: [PATCH 3/3] refactor: optimize split train --- .../split_proofreading/split_datasets.py | 55 +++++++++++++------ .../split_feature_extraction.py | 36 ++++++------ src/neuron_proofreader/utils/img_util.py | 24 -------- 3 files changed, 56 insertions(+), 59 deletions(-) diff --git a/src/neuron_proofreader/split_proofreading/split_datasets.py b/src/neuron_proofreader/split_proofreading/split_datasets.py index 7de935cb..92d72482 100644 --- a/src/neuron_proofreader/split_proofreading/split_datasets.py +++ b/src/neuron_proofreader/split_proofreading/split_datasets.py @@ -9,6 +9,8 @@ """ +from concurrent.futures import ThreadPoolExecutor +from queue import Queue from torch.utils.data import IterableDataset from tqdm import tqdm @@ -42,6 +44,7 @@ def __init__( config, gt_path=None, metadata_path=None, + prefetch=4, segmentation_path=None, soma_centroids=set(), ): @@ -70,6 +73,7 @@ def __init__( # Instance attributes self.config = config self.gt_path = gt_path + self.prefetch = prefetch self.transform = ImageTransforms() if config.ml.transform else False # Build graph @@ -174,6 +178,7 @@ def add_dataset( config, gt_path=None, metadata_path=None, + prefetch=4, segmentation_path=None, soma_centroids=list(), ): @@ -195,6 +200,8 @@ def add_dataset( metadata_path : str, optional Patch to JSON file containing metadata on block that fragments were extracted from. Default is None. + prefetch : int, optional + Number of batches to prefetch. Default is 4. segmentation_path : str, optional Path to the segmentation that fragments were generated from. Default is None. @@ -208,6 +215,7 @@ def add_dataset( config, gt_path=gt_path, metadata_path=metadata_path, + prefetch=prefetch, segmentation_path=segmentation_path, soma_centroids=soma_centroids, ) @@ -223,21 +231,39 @@ def __iter__(self): targets : torch.Tensor Ground truth labels. """ + # Initializations samplers = self.init_samplers() - while len(samplers) > 0: - key = self.get_next_key(samplers) - try: - # Extract features - subgraph = next(samplers[key]) + queue = Queue(maxsize=self.prefetch * len(samplers)) + active_keys = set(samplers.keys()) + + # Launch one prefetch thread per dataset + with ThreadPoolExecutor(max_workers=len(samplers)) as executor: + for key, sampler in samplers.items(): + executor.submit(self._worker, key, sampler, queue) + + # Consume from queue until all datasets exhausted + while active_keys: + key, inputs, targets = queue.get() + if inputs is StopIteration: + active_keys.discard(key) + continue + if isinstance(inputs, Exception): + raise inputs + yield inputs, targets + + def _worker(self, key, sampler, queue): + """ + Runs in a background thread, prefetches extracted features into queue. + """ + try: + for subgraph in sampler: features = self.datasets[key].feature_extractor(subgraph) data = HeteroGraphData(features) - - # Get training inputs - inputs = data.get_inputs() - targets = data.get_targets() - yield inputs, targets - except StopIteration: - del samplers[key] + queue.put((key, data.get_inputs(), data.get_targets())) + except Exception as e: + queue.put((key, e, None)) + finally: + queue.put((key, StopIteration, None)) def generate_proposals(self, search_radius): """ @@ -248,16 +274,11 @@ def generate_proposals(self, search_radius): search_radius : float Search radius used to generate proposals. """ - # Proposal generation for key in tqdm(self.datasets, desc="Generate Proposals"): self.datasets[key].graph.generate_proposals( search_radius, allow_nonleaf_proposals=True ) - # Report results - print("# Proposals:", self.n_proposals()) - print("% Accepts:", self.p_accepts()) - # --- Helpers --- def __len__(self): """ diff --git a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py index 4651151c..e3f75c15 100644 --- a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py +++ b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py @@ -356,23 +356,28 @@ def read_segmentation(self, center, shape): def compute_crop(self, proposal): """ - Extracts an intensity profile along a set of voxel coordinates. + Computes the center and cubic shape of the image patch for a proposal. + + Parameters + ---------- + proposal : Frozenset[int] + Proposal to compute image crop of. Returns ------- - profile : numpy.ndarray - Image with shape (2, H, W, D) containing a raw image and proposal - mask channels. + center : Tuple[int] + Center of the bounding box between the two proposal nodes. + shape : Tuple[int] + Cubic shape large enough to contain both nodes with padding. """ - # Get info - node1, node2 = tuple(proposal) - voxel1 = self.graph.node_voxel(node1) - voxel2 = self.graph.node_voxel(node2) + # Node info + node1, node2 = proposal + voxel1 = np.array(self.graph.node_voxel(node1)) + voxel2 = np.array(self.graph.node_voxel(node2)) # Compute bounds - bounds = img_util.get_minimal_bbox([voxel1, voxel2], self.padding) - center = tuple([int((v1 + v2) / 2) for v1, v2 in zip(voxel1, voxel2)]) - length = np.max([u - l for u, l in zip(bounds["max"], bounds["min"])]) + center = tuple(((voxel1 + voxel2) / 2).astype(int)) + length = np.max(np.abs(voxel2 - voxel1)) + 2 * self.padding return center, (length, length, length) @@ -455,11 +460,6 @@ def get_intensity_profile(self): profile = np.concatenate( (branch1_profile, proposal_profile, branch2_profile) ) - - # Adjust intensities - max_val = np.max(profile) + 1e-5 - self.img = np.minimum(max_val, self.img) / (max_val + 1e-5) - profile /= (max_val + 1e-5) return profile def get_branch_profile(self, node): @@ -496,8 +496,8 @@ def _extract_profile(self, voxels): Image with shape (2, H, W, D) containing a raw image and proposal mask channels. """ - voxels = check_list_length(voxels, min_length=16) - profile = np.array([self.img[tuple(voxel)] for voxel in voxels]) + voxels = np.asarray(check_list_length(voxels, min_length=16)) + profile = self.img[voxels[:, 0], voxels[:, 1], voxels[:, 2]] profile = np.append(profile, [profile.mean(), profile.std()]) return profile diff --git a/src/neuron_proofreader/utils/img_util.py b/src/neuron_proofreader/utils/img_util.py index ee0ee090..ab5d50e9 100644 --- a/src/neuron_proofreader/utils/img_util.py +++ b/src/neuron_proofreader/utils/img_util.py @@ -443,30 +443,6 @@ def get_contained_voxels(voxels, shape, buffer=0): return [v for v in voxels if is_contained(v, shape, buffer)] -def get_minimal_bbox(voxels, buffer=0): - """ - Gets the min and max coordinates of a bounding box that contains "voxels". - - Parameters - ---------- - voxels : numpy.ndarray - Array containing voxel coordinates. - buffer : int, optional - Constant value added/subtracted from the max/min coordinates of the - bounding box. Default is 0. - - Returns - ------- - bbox : Dict[str, numpy.ndarray] - Bounding box. - """ - bbox = { - "min": np.floor(np.min(voxels, axis=0) - buffer).astype(int), - "max": np.ceil(np.max(voxels, axis=0) + buffer).astype(int), - } - return bbox - - def get_neighbors(voxel, shape): """ Gets the neighbors of a given voxel coordinate.