Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc_template/source/conf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Configuration file for the Sphinx documentation builder."""

#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
Expand Down
1 change: 1 addition & 0 deletions src/neuron_proofreader/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
"""Init package"""

__version__ = "0.0.0"
2 changes: 1 addition & 1 deletion src/neuron_proofreader/machine_learning/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self):
RandomFlip3D(),
RandomRotation3D(),
RandomNoise3D(),
RandomContrast3D()
RandomContrast3D(),
]

def __call__(self, patches):
Expand Down
20 changes: 15 additions & 5 deletions src/neuron_proofreader/machine_learning/exaspim_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,7 @@ def sample_bright_voxel(self, brain_id):
pending = dict()
for _ in range(self.prefetch_foreground_sampling):
voxel = self.sample_interior_voxel(brain_id)
thread = executor.submit(
self.read_image, brain_id, voxel
)
thread = executor.submit(self.read_image, brain_id, voxel)
pending[thread] = voxel

# Check if image patch is bright enough
Expand Down Expand Up @@ -489,8 +487,20 @@ def _load_batch(self, start_idx):
)

# Process results
img_patches = np.zeros((batch_size, 1,) + self.patch_shape)
mask_patches = np.zeros((batch_size, 1,) + self.patch_shape)
img_patches = np.zeros(
(
batch_size,
1,
)
+ self.patch_shape
)
mask_patches = np.zeros(
(
batch_size,
1,
)
+ self.patch_shape
)
for i, process in enumerate(as_completed(processes)):
img, mask = process.result()
img_patches[i, 0, ...] = img
Expand Down
23 changes: 14 additions & 9 deletions src/neuron_proofreader/machine_learning/geometric_gnn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,7 @@ def __init__(self, ggnn_name, output_dim=64):
# Set geometric gnn
if ggnn_name == "egnn":
self.geometric_gnn = EGNN(
in_node_dim=1,
hidden_dim=32,
out_node_dim=output_dim
in_node_dim=1, hidden_dim=32, out_node_dim=output_dim
)

# --- Core Routines ---
Expand All @@ -97,7 +95,9 @@ def forward(self, h, x, edge_index, batch):
)

# Pool node embeddings
h_g, x_g, edge_index_g = self.pool_nonbranching_paths(h_g, x_g, edge_index_g)
h_g, x_g, edge_index_g = self.pool_nonbranching_paths(
h_g, x_g, edge_index_g
)

# Encode pooled graph
h_g = self.encode_pooled_graph(h_g, x_g, edge_index_g)
Expand Down Expand Up @@ -158,7 +158,9 @@ def pool_nonbranching_paths(self, h, x, edge_index):
# Finish
h_pooled = torch.stack(h_pooled, dim=0)
x_pooled = torch.stack(x_pooled, dim=0)
edge_index_pooled = self.get_edge_index_pooled(edge_index, node_to_path)
edge_index_pooled = self.get_edge_index_pooled(
edge_index, node_to_path
)
return h_pooled, x_pooled, edge_index_pooled

def get_adj_and_deg(self, edge_index, num_nodes):
Expand Down Expand Up @@ -201,10 +203,13 @@ def extract_subgraph(self, h, x, edge_index, node_mask):
id_map = {int(n): i for i, n in enumerate(node_ids.tolist())}
edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
edge_index_g = edge_index[:, edge_mask]
edge_index_g = torch.stack([
torch.tensor([id_map[int(u)] for u in edge_index_g[0]]),
torch.tensor([id_map[int(v)] for v in edge_index_g[1]])
], dim=0)
edge_index_g = torch.stack(
[
torch.tensor([id_map[int(u)] for u in edge_index_g[0]]),
torch.tensor([id_map[int(v)] for v in edge_index_g[1]]),
],
dim=0,
)
return h_g, x_g, edge_index_g

@staticmethod
Expand Down
21 changes: 12 additions & 9 deletions src/neuron_proofreader/machine_learning/gnn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class VisionHGAT(torch.nn.Module):
Heterogeneous graph attention network that processes multimodal features
such as image patches and feature vectors.
"""

# Class attributes
relations = [
str(("branch", "to", "branch")),
Expand All @@ -47,7 +48,9 @@ def __init__(

# Initial embeddings
self.node_embedding = init_node_embedding(hidden_dim)
self.patch_embedding = init_patch_embedding(patch_shape, hidden_dim // 2)
self.patch_embedding = init_patch_embedding(
patch_shape, hidden_dim // 2
)

# Message passing layers
self.disable_msg_passing = disable_msg_passing
Expand All @@ -58,7 +61,7 @@ def __init__(
else:
self.gat1 = self.init_gat(hidden_dim, hidden_dim, heads)
self.gat2 = self.init_gat(hidden_dim * heads, hidden_dim, heads)
self.output = nn.Linear(hidden_dim * heads ** 2, 1)
self.output = nn.Linear(hidden_dim * heads**2, 1)

# Initialize weights
self.init_weights()
Expand All @@ -81,9 +84,7 @@ def init_mlp_layers(self, hidden_dim, n_layers=2):
for _ in range(n_layers):
layers.append(
nn_geometric.HeteroDictLinear(
hidden_dim,
hidden_dim,
types=("branch", "proposal")
hidden_dim, hidden_dim, types=("branch", "proposal")
)
)
return layers
Expand Down Expand Up @@ -160,10 +161,12 @@ def init_node_embedding(output_dim):
dim_p = node_input_dims["proposal"]

# Set node embedding layer
node_embedding = nn.ModuleDict({
"branch": FeedForwardNet(dim_b, output_dim, 3),
"proposal": FeedForwardNet(dim_p, output_dim // 2, 3),
})
node_embedding = nn.ModuleDict(
{
"branch": FeedForwardNet(dim_b, output_dim, 3),
"proposal": FeedForwardNet(dim_p, output_dim // 2, 3),
}
)
return node_embedding


Expand Down
8 changes: 4 additions & 4 deletions src/neuron_proofreader/machine_learning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(
lr=1e-3,
max_epochs=200,
min_recall=0,
save_mistake_mips=False
save_mistake_mips=False,
):
"""
Instantiates a Trainer object.
Expand Down Expand Up @@ -292,7 +292,7 @@ def compute_stats(y, hat_y):
"f1": avg_f1,
"precision": avg_prec,
"recall": avg_recall,
"accuracy": avg_acc
"accuracy": avg_acc,
}
return stats

Expand Down Expand Up @@ -426,7 +426,7 @@ def __init__(
device="cuda",
lr=1e-3,
max_epochs=200,
save_mistake_mips=False
save_mistake_mips=False,
):
"""
Instantiates a DistributedTrainer object.
Expand All @@ -452,7 +452,7 @@ def __init__(
device=device,
lr=lr,
max_epochs=max_epochs,
save_mistake_mips=save_mistake_mips
save_mistake_mips=save_mistake_mips,
)

# Check that multiple GPUs are available
Expand Down
2 changes: 1 addition & 1 deletion src/neuron_proofreader/machine_learning/vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def __init__(self, checkpoint_path, model_config):
checkpoint_path=checkpoint_path,
model_config=model_config,
task_head_config="binary_classifier",
freeze_encoder=True
freeze_encoder=True,
)

# Instance attributes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def load_fragments(dataset, is_test=False):
sub_df = merge_sites_df.loc[merge_sites_df["brain_id"] == brain_id]
for segmentation_id in sub_df["segmentation_id"].unique():
if (brain_id, segmentation_id) in target_pairs:
swc_pointer = f"{root}/{brain_id}/{segmentation_id}/merged_fragments.zip"
swc_pointer = (
f"{root}/{brain_id}/{segmentation_id}/merged_fragments.zip"
)
dataset.load_fragment_graphs(
brain_id, swc_pointer, use_anisotropy=False
)
Expand Down
24 changes: 11 additions & 13 deletions src/neuron_proofreader/merge_proofreading/merge_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
subgraph_to_point_cloud,
)
from neuron_proofreader.merge_proofreading.merge_dataloading import (
get_brain_merge_sites
get_brain_merge_sites,
)
from neuron_proofreader.skeleton_graph import SkeletonGraph
from neuron_proofreader.utils import (
Expand Down Expand Up @@ -71,6 +71,7 @@ class MergeSiteDataset(Dataset):
patch_shape : Tuple[int], optional
Shape of the 3D image patches to extract.
"""

random_negative_example_prob = 0.8

def __init__(
Expand Down Expand Up @@ -121,9 +122,7 @@ def __init__(
self.merge_site_kdtrees = dict()

# --- Load Data ---
def load_fragment_graphs(
self, brain_id, swc_pointer, use_anisotropy=True
):
def load_fragment_graphs(self, brain_id, swc_pointer, use_anisotropy=True):
"""
Loads fragments containing merge mistakes for a whole-brain dataset,
then stores them in the "graphs" attribute.
Expand All @@ -139,7 +138,7 @@ def load_fragment_graphs(
graph = SkeletonGraph(
anisotropy=self.anisotropy,
node_spacing=self.node_spacing,
use_anisotropy=use_anisotropy
use_anisotropy=use_anisotropy,
)
graph.load(swc_pointer)

Expand Down Expand Up @@ -771,6 +770,7 @@ def generate_negative_examples(self):
negative_examples : List[dict]
List of negative examples collected across all graphs.
"""

# Subroutines
def add_examples():
"""
Expand Down Expand Up @@ -924,7 +924,7 @@ def __init__(
is_multimodal=False,
modality=None,
sampler=None,
use_shuffle=True
use_shuffle=True,
):
"""
Instantiates a MergeSiteDataLoader object.
Expand Down Expand Up @@ -970,11 +970,11 @@ def __iter__(self):
for start in range(0, len(idxs), self.batch_size):
end = min(start + self.batch_size, len(idxs))
if self.is_multimodal and self.modality == "graph":
yield self._load_image_graph_batch(idxs[start: end])
yield self._load_image_graph_batch(idxs[start:end])
elif self.is_multimodal and self.modality == "pointcloud":
yield self._load_image_pc_batch(idxs[start: end])
yield self._load_image_pc_batch(idxs[start:end])
else:
yield self._load_image_batch(idxs[start: end])
yield self._load_image_batch(idxs[start:end])

def _load_image_batch(self, batch_idxs):
"""
Expand Down Expand Up @@ -1084,9 +1084,7 @@ def _load_image_graph_batch(self, idxs):
h.append(h_i)
x.append(x_i)
edge_index.append(edge_index_i)
batches.append(
torch.full((n_i,), i, dtype=torch.long)
)
batches.append(torch.full((n_i,), i, dtype=torch.long))

node_offset += n_i

Expand All @@ -1100,7 +1098,7 @@ def _load_image_graph_batch(self, idxs):
batch = ml_util.TensorDict(
{
"img": ml_util.to_tensor(patches),
"graph": (h, x, edge_index, batches)
"graph": (h, x, edge_index, batches),
}
)
return batch, ml_util.to_tensor(targets)
Expand Down
30 changes: 22 additions & 8 deletions src/neuron_proofreader/merge_proofreading/merge_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def __init__(
prefetch=64,
segmentation_path=None,
subgraph_radius=100,
use_new_mask=False
use_new_mask=False,
):
# Call parent class
super().__init__()
Expand Down Expand Up @@ -317,7 +317,9 @@ def find_fragments_to_search(self):
for nodes in nx.connected_components(self.graph):
# Compute path length
node = util.sample_once(list(nodes))
length = self.graph.cable_length(max_depth=self.min_size, root=node)
length = self.graph.cable_length(
max_depth=self.min_size, root=node
)

# Check if path length satisfies threshold
if length > self.min_size:
Expand Down Expand Up @@ -431,7 +433,7 @@ def __init__(
segmentation_path=None,
step_size=10,
subgraph_radius=100,
use_new_mask=False
use_new_mask=False,
):
# Call parent class
super().__init__(
Expand All @@ -445,7 +447,7 @@ def __init__(
prefetch=prefetch,
segmentation_path=segmentation_path,
subgraph_radius=subgraph_radius,
use_new_mask=use_new_mask
use_new_mask=use_new_mask,
)

# Instance attributes
Expand Down Expand Up @@ -537,7 +539,13 @@ def _get_batch(self, nodes, img, offset):
patch_centers = self.get_patch_centers(nodes) - offset

# Populate batch array
batch = np.empty((len(nodes), 2,) + self.patch_shape)
batch = np.empty(
(
len(nodes),
2,
)
+ self.patch_shape
)
for i, center in enumerate(patch_centers):
s = img_util.get_slices(center, self.patch_shape)
batch[i, 0, ...] = img_util.normalize(img[s])
Expand All @@ -550,7 +558,13 @@ def _get_multimodal_batch(self, nodes, img, offset):
patch_centers = self.get_patch_centers(nodes) - offset

# Populate batch array
patches = np.empty((len(nodes), 2,) + self.patch_shape)
patches = np.empty(
(
len(nodes),
2,
)
+ self.patch_shape
)
point_clouds = np.empty((len(nodes), 3, 3600), dtype=np.float32)
for i, (node, center) in enumerate(zip(nodes, patch_centers)):
s = img_util.get_slices(center, self.patch_shape)
Expand Down Expand Up @@ -610,7 +624,7 @@ def __init__(
prefetch=128,
segmentation_path=None,
subgraph_radius=100,
use_new_mask=False
use_new_mask=False,
):
# Call parent class
super().__init__(
Expand All @@ -623,7 +637,7 @@ def __init__(
prefetch=prefetch,
segmentation_path=segmentation_path,
subgraph_radius=subgraph_radius,
use_new_mask=use_new_mask
use_new_mask=use_new_mask,
)

# Instance attributes
Expand Down
Loading
Loading