diff --git a/src/neuron_proofreader/split_proofreading/split_datasets.py b/src/neuron_proofreader/split_proofreading/split_datasets.py index 650e9d8..6b7c0c6 100644 --- a/src/neuron_proofreader/split_proofreading/split_datasets.py +++ b/src/neuron_proofreader/split_proofreading/split_datasets.py @@ -9,6 +9,8 @@ """ +from concurrent.futures import ThreadPoolExecutor +from queue import Queue from torch.utils.data import IterableDataset from tqdm import tqdm @@ -42,6 +44,7 @@ def __init__( config, gt_path=None, metadata_path=None, + prefetch=4, soma_centroids=set(), ): """ @@ -66,6 +69,7 @@ def __init__( # Instance attributes self.config = config self.gt_path = gt_path + self.prefetch = prefetch self.transform = ImageTransforms() if config.ml.transform else False # Build graph @@ -172,6 +176,7 @@ def add_dataset( config, gt_path=None, metadata_path=None, + prefetch=4, soma_centroids=list(), ): """ @@ -192,6 +197,8 @@ def add_dataset( metadata_path : str, optional Patch to JSON file containing metadata on block that fragments were extracted from. Default is None. + prefetch : int, optional + Number of batches to prefetch. Default is 4. soma_centroids : List[Tuple[int]], optional Phyiscal coordinates of soma centroids. Default is an empty list. """ @@ -202,6 +209,7 @@ def add_dataset( config, gt_path=gt_path, metadata_path=metadata_path, + prefetch=prefetch, soma_centroids=soma_centroids, ) @@ -216,21 +224,39 @@ def __iter__(self): targets : torch.Tensor Ground truth labels. """ + # Initializations samplers = self.init_samplers() - while len(samplers) > 0: - key = self.get_next_key(samplers) - try: - # Extract features - subgraph = next(samplers[key]) + queue = Queue(maxsize=self.prefetch * len(samplers)) + active_keys = set(samplers.keys()) + + # Launch one prefetch thread per dataset + with ThreadPoolExecutor(max_workers=len(samplers)) as executor: + for key, sampler in samplers.items(): + executor.submit(self._worker, key, sampler, queue) + + # Consume from queue until all datasets exhausted + while active_keys: + key, inputs, targets = queue.get() + if inputs is StopIteration: + active_keys.discard(key) + continue + if isinstance(inputs, Exception): + raise inputs + yield inputs, targets + + def _worker(self, key, sampler, queue): + """ + Runs in a background thread, prefetches extracted features into queue. + """ + try: + for subgraph in sampler: features = self.datasets[key].feature_extractor(subgraph) data = HeteroGraphData(features) - - # Get training inputs - inputs = data.get_inputs() - targets = data.get_targets() - yield inputs, targets - except StopIteration: - del samplers[key] + queue.put((key, data.get_inputs(), data.get_targets())) + except Exception as e: + queue.put((key, e, None)) + finally: + queue.put((key, StopIteration, None)) def generate_proposals(self, search_radius): """ @@ -241,16 +267,11 @@ def generate_proposals(self, search_radius): search_radius : float Search radius used to generate proposals. """ - # Proposal generation for key in tqdm(self.datasets, desc="Generate Proposals"): self.datasets[key].graph.generate_proposals( search_radius, allow_nonleaf_proposals=True ) - # Report results - print("# Proposals:", self.n_proposals()) - print("% Accepts:", self.p_accepts()) - # --- Helpers --- def __len__(self): """ @@ -326,7 +347,7 @@ def p_accepts(self): accepts_cnt = 0 for dataset in self.datasets.values(): accepts_cnt += len(dataset.graph.gt_accepts) - return accepts_cnt / (self.n_proposals() + 1e-5) + return 100 * accepts_cnt / (self.n_proposals() + 1e-5) def save_examples_summary(self, path): """ diff --git a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py index 63bab67..918468a 100644 --- a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py +++ b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py @@ -347,13 +347,19 @@ def read_image(self, center, shape): def compute_crop(self, proposal): """ - Extracts an intensity profile along a set of voxel coordinates. + Computes the center and cubic shape of the image patch for a proposal. + + Parameters + ---------- + proposal : Frozenset[int] + Proposal to compute image crop of. Returns ------- - profile : numpy.ndarray - Image with shape (2, H, W, D) containing a raw image and proposal - mask channels. + center : Tuple[int] + Center of the bounding box between the two proposal nodes. + shape : Tuple[int] + Cubic shape large enough to contain both nodes with padding. """ # Node info node1, node2 = proposal @@ -481,8 +487,8 @@ def _extract_profile(self, voxels): Image with shape (2, H, W, D) containing a raw image and proposal mask channels. """ - voxels = check_list_length(voxels, min_length=16) - profile = np.array([self.img[tuple(voxel)] for voxel in voxels]) + voxels = np.asarray(check_list_length(voxels, min_length=16)) + profile = self.img[voxels[:, 0], voxels[:, 1], voxels[:, 2]] profile = np.append(profile, [profile.mean(), profile.std()]) return profile