Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 39 additions & 18 deletions src/neuron_proofreader/split_proofreading/split_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

"""

from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from torch.utils.data import IterableDataset
from tqdm import tqdm

Expand Down Expand Up @@ -42,6 +44,7 @@ def __init__(
config,
gt_path=None,
metadata_path=None,
prefetch=4,
soma_centroids=set(),
):
"""
Expand All @@ -66,6 +69,7 @@ def __init__(
# Instance attributes
self.config = config
self.gt_path = gt_path
self.prefetch = prefetch
self.transform = ImageTransforms() if config.ml.transform else False

# Build graph
Expand Down Expand Up @@ -172,6 +176,7 @@ def add_dataset(
config,
gt_path=None,
metadata_path=None,
prefetch=4,
soma_centroids=list(),
):
"""
Expand All @@ -192,6 +197,8 @@ def add_dataset(
metadata_path : str, optional
Patch to JSON file containing metadata on block that fragments
were extracted from. Default is None.
prefetch : int, optional
Number of batches to prefetch. Default is 4.
soma_centroids : List[Tuple[int]], optional
Phyiscal coordinates of soma centroids. Default is an empty list.
"""
Expand All @@ -202,6 +209,7 @@ def add_dataset(
config,
gt_path=gt_path,
metadata_path=metadata_path,
prefetch=prefetch,
soma_centroids=soma_centroids,
)

Expand All @@ -216,21 +224,39 @@ def __iter__(self):
targets : torch.Tensor
Ground truth labels.
"""
# Initializations
samplers = self.init_samplers()
while len(samplers) > 0:
key = self.get_next_key(samplers)
try:
# Extract features
subgraph = next(samplers[key])
queue = Queue(maxsize=self.prefetch * len(samplers))
active_keys = set(samplers.keys())

# Launch one prefetch thread per dataset
with ThreadPoolExecutor(max_workers=len(samplers)) as executor:
for key, sampler in samplers.items():
executor.submit(self._worker, key, sampler, queue)

# Consume from queue until all datasets exhausted
while active_keys:
key, inputs, targets = queue.get()
if inputs is StopIteration:
active_keys.discard(key)
continue
if isinstance(inputs, Exception):
raise inputs
yield inputs, targets

def _worker(self, key, sampler, queue):
"""
Runs in a background thread, prefetches extracted features into queue.
"""
try:
for subgraph in sampler:
features = self.datasets[key].feature_extractor(subgraph)
data = HeteroGraphData(features)

# Get training inputs
inputs = data.get_inputs()
targets = data.get_targets()
yield inputs, targets
except StopIteration:
del samplers[key]
queue.put((key, data.get_inputs(), data.get_targets()))
except Exception as e:
queue.put((key, e, None))
finally:
queue.put((key, StopIteration, None))

def generate_proposals(self, search_radius):
"""
Expand All @@ -241,16 +267,11 @@ def generate_proposals(self, search_radius):
search_radius : float
Search radius used to generate proposals.
"""
# Proposal generation
for key in tqdm(self.datasets, desc="Generate Proposals"):
self.datasets[key].graph.generate_proposals(
search_radius, allow_nonleaf_proposals=True
)

# Report results
print("# Proposals:", self.n_proposals())
print("% Accepts:", self.p_accepts())

# --- Helpers ---
def __len__(self):
"""
Expand Down Expand Up @@ -326,7 +347,7 @@ def p_accepts(self):
accepts_cnt = 0
for dataset in self.datasets.values():
accepts_cnt += len(dataset.graph.gt_accepts)
return accepts_cnt / (self.n_proposals() + 1e-5)
return 100 * accepts_cnt / (self.n_proposals() + 1e-5)

def save_examples_summary(self, path):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,13 +347,19 @@ def read_image(self, center, shape):

def compute_crop(self, proposal):
"""
Extracts an intensity profile along a set of voxel coordinates.
Computes the center and cubic shape of the image patch for a proposal.

Parameters
----------
proposal : Frozenset[int]
Proposal to compute image crop of.

Returns
-------
profile : numpy.ndarray
Image with shape (2, H, W, D) containing a raw image and proposal
mask channels.
center : Tuple[int]
Center of the bounding box between the two proposal nodes.
shape : Tuple[int]
Cubic shape large enough to contain both nodes with padding.
"""
# Node info
node1, node2 = proposal
Expand Down Expand Up @@ -481,8 +487,8 @@ def _extract_profile(self, voxels):
Image with shape (2, H, W, D) containing a raw image and proposal
mask channels.
"""
voxels = check_list_length(voxels, min_length=16)
profile = np.array([self.img[tuple(voxel)] for voxel in voxels])
voxels = np.asarray(check_list_length(voxels, min_length=16))
profile = self.img[voxels[:, 0], voxels[:, 1], voxels[:, 2]]
profile = np.append(profile, [profile.mean(), profile.std()])
return profile

Expand Down
Loading