From 96d05bd4506bc47c279533cd3e4e98f373e5d900 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Thu, 11 Jun 2026 19:10:08 +0000 Subject: [PATCH 01/10] feat: initial version for path transformer --- .../geometric_learning/curve_augmentation.py | 191 ++++++++++ .../geometric_learning/path_transformer.py | 329 ++++++++++++++++++ .../geometric_learning/skel_datamodules.py | 277 +++++++++++++++ ...{image_augmentation.py => augmentation.py} | 41 ++- .../machine_learning/image_dataloader.py | 2 +- src/neuron_proofreader/skeleton_graph.py | 39 ++- 6 files changed, 858 insertions(+), 21 deletions(-) create mode 100644 src/neuron_proofreader/geometric_learning/curve_augmentation.py create mode 100644 src/neuron_proofreader/geometric_learning/path_transformer.py create mode 100644 src/neuron_proofreader/geometric_learning/skel_datamodules.py rename src/neuron_proofreader/machine_learning/{image_augmentation.py => augmentation.py} (83%) diff --git a/src/neuron_proofreader/geometric_learning/curve_augmentation.py b/src/neuron_proofreader/geometric_learning/curve_augmentation.py new file mode 100644 index 0000000..bc79576 --- /dev/null +++ b/src/neuron_proofreader/geometric_learning/curve_augmentation.py @@ -0,0 +1,191 @@ +""" +Created on Wed June 11 12:00:00 2026 + +@author: Anna Grim +@email: anna.grim@alleninstitute.org + +Code for applying data augmentation to 3D space curves. + +""" + +import numpy as np +import random + + +class CurveTransforms: + """ + Class that applies a sequence of transforms to a 3D space curve. + """ + def __init__(self): + """ + Initializes a CurveTransforms instance that applies augmentation to + a 3D space curve. + """ + self.transforms = [ + RandomRotation3D(), RandomMirror3D(), RandomJitter3D(), + ] + + def __call__(self, curve): + """ + Applies transforms to the input curve. + + Parameters + ---------- + curve : numpy.ndarray + Array of shape (N, 3) representing N points in 3D space (x, y, z). + """ + # Check whether to reverse path + if random.random() > 0.5: + curve = np.flip(curve) + + # Apply transforms + for transform in self.transforms: + curve = transform(curve) + return curve + + +# --- Noise Transforms --- +class RandomJitter3D: + """ + Randomly adds Gaussian noise to each point in a 3D curve. + """ + def __init__(self, sigma=1, p=0.5): + """ + Initializes a RandomJitter3D transformer. + + Parameters + ---------- + sigma : float, optional + Standard deviation of the Gaussian noise. Default is 0.01. + p : float, optional + Probability of applying the transform. Default is 0.5. + """ + self.sigma = sigma + self.p = p + + def __call__(self, curve): + """ + Applies random jitter to the input curve. + + Parameters + ---------- + curve : numpy.ndarray + Array of shape (N, 3) representing N points in 3D space. + + Returns + ------- + numpy.ndarray + Jittered curve of shape (N, 3). + """ + if random.random() > self.p: + return curve + noise = np.random.normal(0, self.sigma, size=curve.shape) + return curve + noise + + +# --- Geometric Transforms --- +class RandomRotation3D: + """ + Applies a random 3D rotation to a curve about a random axis. + """ + def __init__(self, max_angle=np.pi, p=0.5): + """ + Initializes a RandomRotation3D transformer. + + Parameters + ---------- + max_angle : float, optional + Maximum rotation angle in radians. Default is pi (full rotation). + p : float, optional + Probability of applying the transform. Default is 0.5. + """ + self.max_angle = max_angle + self.p = p + + def _rotation_matrix(self, axis, angle): + """ + Computes the Rodrigues rotation matrix for a given axis and angle. + + Parameters + ---------- + axis : numpy.ndarray + Unit vector of shape (3,) representing the rotation axis. + angle : float + Rotation angle in radians. + + Returns + ------- + numpy.ndarray + Rotation matrix of shape (3, 3). + """ + c, s = np.cos(angle), np.sin(angle) + x, y, z = axis + return np.array([ + [c + x*x*(1-c), x*y*(1-c) - z*s, x*z*(1-c) + y*s], + [y*x*(1-c) + z*s, c + y*y*(1-c), y*z*(1-c) - x*s], + [z*x*(1-c) - y*s, z*y*(1-c) + x*s, c + z*z*(1-c)], + ]) + + def __call__(self, curve): + """ + Applies a random rotation to the input curve about its centroid. + + Parameters + ---------- + curve : numpy.ndarray + Array of shape (N, 3) representing N points in 3D space. + + Returns + ------- + numpy.ndarray + Rotated curve of shape (N, 3). + """ + if random.random() > self.p: + return curve + axis = np.random.randn(3) + axis /= np.linalg.norm(axis) + angle = random.uniform(-self.max_angle, self.max_angle) + R = self._rotation_matrix(axis, angle) + centroid = curve.mean(axis=0) + return (curve - centroid) @ R.T + centroid + + +class RandomMirror3D: + """ + Randomly mirrors a 3D curve along one or more axes about its centroid. + """ + def __init__(self, axes=(0, 1, 2), p=0.5): + """ + Initializes a RandomMirror3D transformer. + + Parameters + ---------- + axes : Tuple[int], optional + Axes to consider for mirroring. Default is (0, 1, 2). + p : float, optional + Per-axis probability of mirroring. Default is 0.5. + """ + self.axes = axes + self.p = p + + def __call__(self, curve): + """ + Applies random mirroring to the input curve. + + Parameters + ---------- + curve : numpy.ndarray + Array of shape (N, 3) representing N points in 3D space. + + Returns + ------- + numpy.ndarray + Mirrored curve of shape (N, 3). + """ + curve = curve.copy() + centroid = curve.mean(axis=0) + for axis in self.axes: + if random.random() > self.p: + continue + curve[:, axis] = 2 * centroid[axis] - curve[:, axis] + return curve diff --git a/src/neuron_proofreader/geometric_learning/path_transformer.py b/src/neuron_proofreader/geometric_learning/path_transformer.py new file mode 100644 index 0000000..b8e1d0b --- /dev/null +++ b/src/neuron_proofreader/geometric_learning/path_transformer.py @@ -0,0 +1,329 @@ +""" +Created on Wed June 10 12:00:00 2026 + +@author: Anna Grim +@email: anna.grim@alleninstitute.org + +... + +""" + +import numpy as np +import torch +import torch.nn as nn + + +class CurveEncoder(nn.Module): + """ + Transformer encoder that maps a 3D space curve, normalized to the unit + sphere, to a fixed-size latent vector. The curve is tokenized into fixed- + length segments with a fixed learned token for the start point at the + origin and a projected token for the end point. Positional encodings are + sinusoidal over the normalized arc position [0, 1], making the encoder + robust to varying numbers of points and path lengths. + """ + + def __init__( + self, + segment_len=10, + d_token=64, + n_heads=4, + n_layers=4, + d_ff=64, + latent_dim=32, + dropout=0.1, + ): + """ + Parameters + ---------- + segment_len : int + Number of points per segment token. + d_token : int + Dimension of each token. + n_heads : int + Number of attention heads. + n_layers : int + Number of transformer encoder layers. + d_ff : int + Feed-forward hidden dimension. + latent_dim : int + Dimension of the output latent vector. + dropout : float + Dropout probability. + """ + # Call parent class + super().__init__() + + # Instance attributes + self.segment_len = segment_len + self.start_token = nn.Parameter(torch.randn(1, 1, d_token)) + self.end_token_proj = nn.Linear(3, d_token) + self.segment_proj = nn.Linear(segment_len * 3, d_token) + + # Archictecture + encoder_layer = nn.TransformerEncoderLayer( + d_model=d_token, + nhead=n_heads, + dim_feedforward=d_ff, + dropout=dropout, + batch_first=True, + ) + self.transformer = nn.TransformerEncoder( + encoder_layer, num_layers=n_layers + ) + self.to_latent = nn.Sequential( + nn.LayerNorm(d_token), + nn.Linear(d_token, latent_dim), + ) + + def forward(self, curve, token_mask=None): + """ + Parameters + ---------- + curve : torch.Tensor + Shape (B, N, 3), normalized to the unit sphere, with + curve[:, 0] == [0, 0, 0]. N can vary across calls. + mask : torch.Tensor, optional + Shape (B, N), True where padding (to be ignored). Default is None. + + Returns + ------- + z : torch.Tensor + Latent vector of shape (B, latent_dim). + tokens : torch.Tensor + Per-token encodings of shape (B, n_segments + 2, d_token). + """ + B, N, _ = curve.shape + n_segments = N // self.segment_len + + # Start and end tokens + start_tok = self.start_token.expand(B, -1, -1) # (B, 1, d_token) + end_tok = self.end_token_proj(curve[:, -1, :]).unsqueeze(1) # (B, 1, d_token) + + # Segment tokens + segments = curve[:, :n_segments * self.segment_len, :] + segments = segments.reshape(B, n_segments, self.segment_len * 3) + seg_tokens = self.segment_proj(segments) # (B, n_seg, d_token) + + # Concatenate: [start | segments | end] + tokens = torch.cat([start_tok, seg_tokens, end_tok], dim=1) # (B, n_seg+2, d_token) + + # Sinusoidal positional encoding over arc position + pe = sinusoidal_encoding(tokens.shape[1], tokens.shape[2], tokens.device) + if token_mask is not None: + pe = pe * (~token_mask).unsqueeze(-1).float() # Zero out pe for padding + tokens = tokens + pe + + # Convert point-level mask to token-level mask + token_mask = None + if token_mask is not None: + seg_mask = token_mask[:, ::self.segment_len][:, :n_segments] # (B, n_seg) + token_mask = torch.cat([ + torch.zeros(B, 1, dtype=torch.bool, device=token_mask.device), + seg_mask, + torch.zeros(B, 1, dtype=torch.bool, device=token_mask.device), + ], dim=1) # (B, n_seg+2) + + tokens = self.transformer(tokens, src_key_padding_mask=token_mask) + + # Mean pool over non-padding tokens only + if token_mask is not None: + valid = (~token_mask).unsqueeze(-1).float() + z = self.to_latent((tokens * valid).sum(dim=1) / valid.sum(dim=1)) + else: + z = self.to_latent(tokens.mean(dim=1)) + return z, tokens + + +class CurveDecoder(nn.Module): + """ + Transformer decoder that reconstructs a 3D space curve from a latent + vector and the encoder's token representations. Positional queries are + sinusoidally encoded over arc position and biased by the global latent. + The output resolution can differ from the encoder input, allowing + decoding at arbitrary granularity. + """ + def __init__( + self, + n_points=100, + segment_len=10, + d_token=128, + n_heads=4, + n_layers=4, + d_ff=512, + latent_dim=64, + dropout=0.1, + ): + """ + Parameters + ---------- + n_points : int + Default number of output curve points. + segment_len : int + Number of points per segment token (must match encoder). + d_token : int + Dimension of each token throughout the transformer. + n_heads : int + Number of attention heads. + n_layers : int + Number of transformer decoder layers. + d_ff : int + Feed-forward hidden dimension. + latent_dim : int + Dimension of the input latent vector. + dropout : float + Dropout probability. + """ + super().__init__() + self.segment_len = segment_len + self.n_segments = n_points // segment_len + + # Project latent to d_token to bias the positional queries + self.latent_proj = nn.Linear(latent_dim, d_token) + + decoder_layer = nn.TransformerDecoderLayer( + d_model=d_token, + nhead=n_heads, + dim_feedforward=d_ff, + dropout=dropout, + batch_first=True, + ) + self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=n_layers) + + self.to_points = nn.Sequential( + nn.LayerNorm(d_token), + nn.Linear(d_token, segment_len * 3), + ) + + def forward(self, z, encoder_tokens, encoder_mask=None, n_segments=None): + """ + Parameters + ---------- + z : torch.Tensor + Latent vector of shape (B, latent_dim). + encoder_tokens : torch.Tensor + Per-token encoder outputs of shape (B, n_segments + 2, d_token). + encoder_mask : torch.Tensor, optional + Shape (B, n_segments + 2), True where padding. Passed as + memory_key_padding_mask to cross-attention. Default is None. + n_segments : int, optional + Number of output segments. Inferred from encoder tokens if not + provided. + + Returns + ------- + curve : torch.Tensor + Reconstructed curve of shape (B, n_segments * segment_len, 3). + """ + B = z.shape[0] + d_token = encoder_tokens.shape[2] + n_segments = n_segments or self.n_segments + + # Sinusoidal queries over arc position, biased by global latent + pe = sinusoidal_encoding(n_segments, d_token, encoder_tokens.device) + latent = self.latent_proj(z).unsqueeze(1) # (B, 1, d_token) + queries = pe.expand(B, -1, -1) + latent # (B, n_seg, d_token) + + out = self.transformer( + queries, + encoder_tokens, + memory_key_padding_mask=encoder_mask, + ) # (B, n_seg, d_token) + + segments = self.to_points(out) # (B, n_seg, seg_len*3) + curve = segments.reshape(B, n_segments * self.segment_len, 3) + return curve + + +class CurveAutoencoder(nn.Module): + """ + ARBOR: Autoencoder for Representing Branching and Ordered curves in 3D. + + Encodes a uniformly sampled, ordered 3D curve normalized to the unit + sphere to a latent vector, and decodes it back to a curve of the same + length. Robust to varying input lengths via dynamic sinusoidal positional + encoding over the normalized arc position. + """ + def __init__( + self, + n_points=100, + segment_len=10, + d_token=128, + n_heads=4, + n_layers=4, + d_ff=512, + latent_dim=64, + dropout=0.1, + ): + super().__init__() + self.encoder = CurveEncoder( + segment_len=segment_len, + d_token=d_token, + n_heads=n_heads, + n_layers=n_layers, + d_ff=d_ff, + latent_dim=latent_dim, + dropout=dropout, + ) + self.decoder = CurveDecoder( + n_points=n_points, + segment_len=segment_len, + d_token=d_token, + n_heads=n_heads, + n_layers=n_layers, + d_ff=d_ff, + latent_dim=latent_dim, + dropout=dropout, + ) + + def forward(self, curve): + """ + Parameters + ---------- + curve : torch.Tensor + Shape (B, N, 3), normalized to the unit sphere, curve[:, 0] == 0. + + Returns + ------- + reconstruction : torch.Tensor + Shape (B, N, 3). + z : torch.Tensor + Latent vector of shape (B, latent_dim). + """ + z, encoder_tokens = self.encoder(curve) + n_segments = (curve.shape[1] // self.decoder.segment_len) + reconstruction = self.decoder(z, encoder_tokens, n_segments=n_segments) + return reconstruction, z + + def encode(self, curve): + z, _ = self.encoder(curve) + return z + + +# --- Helpers --- +def sinusoidal_encoding(n_tokens, d_token, device): + """ + Sinusoidal positional encoding over normalised arc position [0, 1]. + + Parameters + ---------- + n_tokens : int + Number of tokens (segments + 2 endpoint tokens). + d_token : int + Model dimension. + device : torch.device + Device to create the encoding on. + + Returns + ------- + torch.Tensor + Encoding of shape (1, n_tokens, d_token). + """ + position = torch.linspace(0, 1, n_tokens, device=device).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_token, 2, device=device) * (-np.log(10000.0) / d_token) + ) + pe = torch.zeros(n_tokens, d_token, device=device) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + return pe.unsqueeze(0) diff --git a/src/neuron_proofreader/geometric_learning/skel_datamodules.py b/src/neuron_proofreader/geometric_learning/skel_datamodules.py new file mode 100644 index 0000000..2ab7319 --- /dev/null +++ b/src/neuron_proofreader/geometric_learning/skel_datamodules.py @@ -0,0 +1,277 @@ +""" +Created on Mon June 8 17:00:00 2026 + +@author: Anna Grim +@email: anna.grim@alleninstitute.org + +... + +""" + +from random import shuffle +from torch.utils.data import Dataset, DataLoader, Sampler + +import networkx as nx +import numpy as np +import torch + +from neuron_proofreader.skeleton_graph import SkeletonGraph +from neuron_proofreader.utils import util + + +# --- Dataset Classes --- +class PathsDataset(Dataset): + + def __init__( + self, + brain_id, + swcs_path, + bin_width=400, + graph_config=None, + max_path_length=10e4, + transform=None, + ): + # Instance attributes + self.brain_id = brain_id + self.bin_width = bin_width + self.max_path_length = max_path_length + + # Core data structures + self.graph = self.load_skeletons(graph_config, swcs_path) + self.paths = self.irreducible_paths() + self.transform = transform + + self.set_bins() + + def load_skeletons(self, config, swcs_path): + graph = SkeletonGraph( + anisotropy=config.anisotropy, + min_cable_length=config.min_cable_length, + node_spacing=config.node_spacing, + use_anisotropy=config.use_anisotropy, + verbose=config.verbose, + ) + graph.load(swcs_path) + return graph + + def set_bins(self): + # Initialize bins + db = self.bin_width + n_bins = int(np.ceil(self.max_path_length / db)) + self.bins = {(i * db, (i + 1) * db): [] for i in range(n_bins)} + + # Store path indices in bins + for idx, p in enumerate(self.paths): + i = min(int(self.path_length(p) / db), n_bins - 1) + self.bins[(i * db, (i + 1) * db)].append(idx) + + # --- Get Examples --- + def __getitem__(self, bin_id): + # Get path + idx = util.sample_once(self.bins[bin_id]) + path = self.paths[idx].copy() + + # Check whether to subsample + if self.path_length(path) > self.max_path_length: + new_length = np.random.uniform(*bin_id) + node = util.sample_once(path) + path = self.path_thru_node(node, max_depth=new_length) + + # Check whether to transform + curve = self.node_xyz[path] + if self.transform: + curve = self.transform(curve) + + # Normalize + curve = (curve - curve[0]) / self.max_path_length + return curve + + # --- Helpers --- + def __getattr__(self, name): + return getattr(self.graph, name) + + def __len__(self): + return len(self.paths) + + def __repr__(self): + lengths = [self.path_length(p) for p in self.paths] + num_neurons = nx.number_connected_components(self.graph) + return ( + f"BrainDataset(" + f"\n brain_id={self.brain_id}, " + f"\n num_neurons={num_neurons}, " + f"\n num_paths={len(self)}, " + f"\n min_path_length={np.min(lengths):.2f}, " + f"\n mean_path_length={np.mean(lengths):.2f}, " + f"\n max_path_length={np.max(lengths):.2f}," + f"\n)" + ) + + +class PathsDatasetCollection(Dataset): + + def __init__(self, datasets, examples_per_bin=32): + """ + Collection of PathsDataset instances, one per brain, with a unified + bin structure for sampling uniformly across brains and path lengths. + + Parameters + ---------- + datasets : List[PathsDataset] + List of PathsDataset instances, one per brain. + """ + self.datasets = datasets + self.examples_per_bin = examples_per_bin + self.set_bins() + + def set_bins(self): + """ + Builds a unified bin structure across all datasets. Each bin key is a + (lower, upper) tuple and maps to a list of (dataset_idx, path_idx) + pairs. + """ + # Collect all bin keys across datasets + all_keys = set() + for ds in self.datasets: + all_keys.update(ds.bins.keys()) + + self.bins = {k: [] for k in sorted(all_keys)} + for ds_idx, ds in enumerate(self.datasets): + for bin_id, path_indices in ds.bins.items(): + for path_idx in path_indices: + self.bins[bin_id].append((ds_idx, path_idx)) + + def __getitem__(self, bin_id): + """ + Samples a random (dataset, path) pair from the given bin. + + Parameters + ---------- + bin_id : Tuple[float, float] + The (lower, upper) bin key. + + Returns + ------- + numpy.ndarray + Normalized curve of shape (N, 3). + """ + ds_idx, path_idx = util.sample_once(self.bins[bin_id]) + return self.datasets[ds_idx][bin_id] + + # --- Helpers --- + def all_path_lengths(self): + path_lengths = list() + for dataset in self.datasets: + path_lengths.extend( + [dataset.path_length(p) for p in dataset.paths] + ) + return np.array(path_lengths) + + def generate_bin_ids(self): + nonempty_keys = [k for k, v in self.bins.items() if v] + bin_ids = list(nonempty_keys) * self.examples_per_bin + shuffle(bin_ids) + return bin_ids + + def __len__(self): + return sum(len(ds) for ds in self.datasets) + + def __repr__(self): + n_brains = len(self.datasets) + n_paths = len(self) + non_empty = sum(1 for v in self.bins.values() if len(v) > 0) + return ( + f"PathsDatasetCollection(" + f"num_brains={n_brains}, " + f"num_paths={n_paths}, " + f"num_bins={non_empty})" + ) + + +# --- DataLoader Classes --- +class BinSampler(Sampler): + """ + Sampler that yields bin IDs uniformly across all non-empty bins, + with a fixed number of examples per bin per epoch. + """ + def __init__(self, dataset, examples_per_bin=10): + """ + Parameters + ---------- + dataset : PathsDatasetCollection + Dataset to sample from. + examples_per_bin : int, optional + Number of examples to draw from each bin per epoch. Default is 10. + """ + self.examples_per_bin = examples_per_bin + self.non_empty_bins = [k for k, v in dataset.bins.items() if len(v) > 0] + + def __iter__(self): + bin_ids = self.non_empty_bins * self.examples_per_bin + shuffle(bin_ids) + return iter(bin_ids) + + def __len__(self): + return len(self.non_empty_bins) * self.examples_per_bin + + +def collate_curves(curves): + """ + Pads a list of curves to the longest in the batch and generates an + attention mask. + + Parameters + ---------- + curves : List[numpy.ndarray] + Each of shape (N_i, 3), where N_i can vary. + + Returns + ------- + padded : torch.Tensor + Shape (B, N_max, 3), zero-padded. + mask : torch.Tensor + Shape (B, N_max), True where padding. + """ + lengths = [len(c) for c in curves] + n_max = max(lengths) + B = len(curves) + + padded = torch.zeros(B, n_max, 3) + mask = torch.ones(B, n_max, dtype=torch.bool) + for i, (c, l) in enumerate(zip(curves, lengths)): + padded[i, :l] = torch.tensor(c) + mask[i, :l] = False + return padded, mask + + +def build_dataloader( + dataset, batch_size=32, examples_per_bin=32, num_workers=0 +): + """ + Builds a DataLoader for a PathsDatasetCollection that samples uniformly + across all non-empty bins. + + Parameters + ---------- + dataset : PathsDatasetCollection + Dataset to load from. + examples_per_bin : int, optional + Number of examples per bin per epoch. Default is 10. + batch_size : int, optional + Number of curves per batch. Default is 32. + num_workers : int, optional + Number of worker processes for data loading. Default is 4. + + Returns + ------- + DataLoader + """ + sampler = BinSampler(dataset, examples_per_bin=examples_per_bin) + return DataLoader( + dataset, + batch_size=batch_size, + collate_fn=collate_curves, + num_workers=num_workers, + pin_memory=True, + sampler=sampler, + ) diff --git a/src/neuron_proofreader/machine_learning/image_augmentation.py b/src/neuron_proofreader/machine_learning/augmentation.py similarity index 83% rename from src/neuron_proofreader/machine_learning/image_augmentation.py rename to src/neuron_proofreader/machine_learning/augmentation.py index 93b2d72..5cb974a 100644 --- a/src/neuron_proofreader/machine_learning/image_augmentation.py +++ b/src/neuron_proofreader/machine_learning/augmentation.py @@ -14,6 +14,7 @@ import random +# --- Image Augmentation --- class ImageTransforms: """ Class that applies a sequence of transforms to a 3D image and segmentation @@ -42,7 +43,7 @@ def __call__(self, patches): ---------- patches : numpy.ndarray Image with the shape (2, H, W, D), where the first channel is the - input image and second is the segmentation. + raw image and second is a mask. """ for transform in self.transforms: patches = transform(patches) @@ -73,8 +74,8 @@ def __call__(self, patches): Parameters ---------- patches : numpy.ndarray - Image with the shape (2, H, W, D), where "patches[0, ...]" is from - the input image and "patches[1, ...]" is from the segmentation. + Image with the shape (2, H, W, D), where the first channel is the + raw image and second is a mask. """ for axis in self.axes: if random.random() > 0.5: @@ -95,7 +96,7 @@ def __init__(self, angles=(-90, 90), axes=((0, 1), (0, 2), (1, 2))): Parameters ---------- angles : Tuple[int], optional - Maximum angle of rotation. Default is (-45, 45). + Maximum angle of rotation. Default is (-90, 90). axis : Tuple[Tuple[int]], optional Axes to apply rotation. Default is ((0, 1), (0, 2), (1, 2)) """ @@ -109,8 +110,8 @@ def __call__(self, patches): Parameters ---------- patches : numpy.ndarray - Image with the shape (2, H, W, D), where "patches[0, ...]" is from - the input image and "patches[1, ...]" is from the segmentation. + Image with the shape (2, H, W, D), where the first channel is the + raw image and second is a mask. """ for axes in self.axes: if random.random() < 0.5: @@ -120,7 +121,7 @@ def __call__(self, patches): return patches @staticmethod - def rotate3d(img_patch, angle, axes, is_segmentation=False): + def rotate3d(img_patch, angle, axes, is_mask=False): """ Rotates a 3D image patch around the specified axes by a given angle. @@ -132,13 +133,12 @@ def rotate3d(img_patch, angle, axes, is_segmentation=False): Angle (in degrees) by which to rotate the image patch around the specified axes. axes : Tuple[int] - Tuple representing the two axes of rotation. - is_segmentation : bool, optional - Indication of whether the image is a segmentation. Default is - False. + Two axes of rotation. + is_mask : bool, optional + True if the image is a mask. """ - order = 0 if is_segmentation else 3 - multipler = 4 if is_segmentation else 1 + order = 0 if is_mask else 3 + multipler = 4 if is_mask else 1 img_patch = rotate( multipler * img_patch, angle, @@ -174,8 +174,8 @@ def __call__(self, patches): Parameters ---------- patches : numpy.ndarray - Image with the shape (2, H, W, D), where "patches[0, ...]" is from - the input image and "patches[1, ...]" is from the segmentation. + Image with the shape (2, H, W, D), where the first channel is the + raw image and second is a mask. Returns ------- @@ -226,8 +226,8 @@ def __call__(self, patches): Parameters ---------- patches : numpy.ndarray - Image with the shape (2, H, W, D), where the zeroth channel is - from the raw image and first channel is from the segmentation. + Image with the shape (2, H, W, D), where the first channel is the + raw image and second is a mask. """ lo = np.percentile(patches[0], np.random.uniform(*self.p_low)) hi = np.percentile(patches[0], np.random.uniform(*self.p_high)) @@ -260,10 +260,13 @@ def __call__(self, patches): Parameters ---------- patches : numpy.ndarray - Image with the shape (2, H, W, D), where "patches[0, ...]" is from - the input image and "patches[1, ...]" is from the segmentation. + Image with the shape (2, H, W, D), where the first channel is the + raw image and second is a mask. """ std = self.max_std * random.random() patches[0] += np.random.uniform(-std, std, patches[0].shape) patches[0] = np.clip(patches[0], 0, 1) return patches + + +# --- 3D Space Curve Augmentation --- diff --git a/src/neuron_proofreader/machine_learning/image_dataloader.py b/src/neuron_proofreader/machine_learning/image_dataloader.py index a0c1834..bd9c8c5 100644 --- a/src/neuron_proofreader/machine_learning/image_dataloader.py +++ b/src/neuron_proofreader/machine_learning/image_dataloader.py @@ -14,7 +14,7 @@ import tensorstore as ts from neuron_proofreader.configs import ImageConfig -from neuron_proofreader.machine_learning.image_augmentation import ( +from neuron_proofreader.machine_learning.augmentation import ( ImageTransforms, ) from neuron_proofreader.utils import geometry_util, img_util, util diff --git a/src/neuron_proofreader/skeleton_graph.py b/src/neuron_proofreader/skeleton_graph.py index ca47684..cc4833e 100644 --- a/src/neuron_proofreader/skeleton_graph.py +++ b/src/neuron_proofreader/skeleton_graph.py @@ -853,6 +853,43 @@ def irreducible_nodes(self): """ return {i for i in map(int, self.nodes) if self.degree[i] != 2} + def irreducible_paths(self): + """ + Extracts non-branching paths between irreducible nodes (degree 1 or >= 3). + + Returns + ------- + paths : List[numpy.ndarray] + Each entry is an ordered list of node IDs forming a path between two + irreducible nodes, inclusive of both endpoints. + """ + # Initializations + irreducible = {n for n in self.nodes if self.degree(n) != 2} + paths = [] + visited_edges = set() + + # Search + for source in irreducible: + for nb in self.neighbors(source): + edge = frozenset((source, nb)) + if edge in visited_edges: + continue + visited_edges.add(edge) + + # Walk along degree-2 chain until next irreducible node + path = [source, nb] + prev, curr = source, nb + while curr not in irreducible: + nxt = next(n for n in self.neighbors(curr) if n != prev) + edge = frozenset((curr, nxt)) + visited_edges.add(edge) + path.append(nxt) + prev, curr = curr, nxt + + paths.append(np.array(path, dtype=int)) + + return paths + def leaf_nodes(self): """ Gets all leaf nodes in the graph. @@ -1053,7 +1090,7 @@ def path_length(self, path): """ if len(path) > 1: diffs = self.node_xyz[path[1:]] - self.node_xyz[path[:-1]] - return np.sqrt(np.sum(diffs**2)) + return np.sum(np.sqrt(np.sum(diffs**2, axis=1))) else: return 0 From 4e1169af17085bdc1cf752ae8e0d79828b6ef20f Mon Sep 17 00:00:00 2001 From: anna-grim Date: Thu, 11 Jun 2026 19:16:21 +0000 Subject: [PATCH 02/10] refactor: updated start token --- src/neuron_proofreader/geometric_learning/path_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neuron_proofreader/geometric_learning/path_transformer.py b/src/neuron_proofreader/geometric_learning/path_transformer.py index b8e1d0b..fba9326 100644 --- a/src/neuron_proofreader/geometric_learning/path_transformer.py +++ b/src/neuron_proofreader/geometric_learning/path_transformer.py @@ -56,7 +56,7 @@ def __init__( # Instance attributes self.segment_len = segment_len - self.start_token = nn.Parameter(torch.randn(1, 1, d_token)) + self.start_token = nn.Linear(3, d_token) self.end_token_proj = nn.Linear(3, d_token) self.segment_proj = nn.Linear(segment_len * 3, d_token) From 0fc63b4593b9eb107961d7d4b4bcb54cf0cf176d Mon Sep 17 00:00:00 2001 From: anna-grim Date: Sat, 13 Jun 2026 00:52:05 +0000 Subject: [PATCH 03/10] feat: curve encoder training working --- .../geometric_learning/curve_augmentation.py | 34 ++- .../geometric_learning/curve_datamodules.py | 267 +++++++++++++++++ ...th_transformer.py => curve_transformer.py} | 125 +++++--- .../geometric_learning/skel_datamodules.py | 277 ------------------ .../machine_learning/image_dataloader.py | 2 +- .../machine_learning/train.py | 4 +- .../merge_proofreading/merge_datamodules.py | 44 +-- src/neuron_proofreader/skeleton_graph.py | 4 +- src/neuron_proofreader/utils/graph_util.py | 4 +- src/neuron_proofreader/utils/swc_util.py | 4 +- src/neuron_proofreader/utils/util.py | 2 +- 11 files changed, 413 insertions(+), 354 deletions(-) create mode 100644 src/neuron_proofreader/geometric_learning/curve_datamodules.py rename src/neuron_proofreader/geometric_learning/{path_transformer.py => curve_transformer.py} (75%) delete mode 100644 src/neuron_proofreader/geometric_learning/skel_datamodules.py diff --git a/src/neuron_proofreader/geometric_learning/curve_augmentation.py b/src/neuron_proofreader/geometric_learning/curve_augmentation.py index bc79576..e569a95 100644 --- a/src/neuron_proofreader/geometric_learning/curve_augmentation.py +++ b/src/neuron_proofreader/geometric_learning/curve_augmentation.py @@ -16,13 +16,16 @@ class CurveTransforms: """ Class that applies a sequence of transforms to a 3D space curve. """ + def __init__(self): """ Initializes a CurveTransforms instance that applies augmentation to a 3D space curve. """ self.transforms = [ - RandomRotation3D(), RandomMirror3D(), RandomJitter3D(), + RandomRotation3D(), + RandomMirror3D(), + RandomJitter3D(), ] def __call__(self, curve): @@ -49,7 +52,8 @@ class RandomJitter3D: """ Randomly adds Gaussian noise to each point in a 3D curve. """ - def __init__(self, sigma=1, p=0.5): + + def __init__(self, sigma=0.1, p=0.5): """ Initializes a RandomJitter3D transformer. @@ -88,6 +92,7 @@ class RandomRotation3D: """ Applies a random 3D rotation to a curve about a random axis. """ + def __init__(self, max_angle=np.pi, p=0.5): """ Initializes a RandomRotation3D transformer. @@ -120,11 +125,25 @@ def _rotation_matrix(self, axis, angle): """ c, s = np.cos(angle), np.sin(angle) x, y, z = axis - return np.array([ - [c + x*x*(1-c), x*y*(1-c) - z*s, x*z*(1-c) + y*s], - [y*x*(1-c) + z*s, c + y*y*(1-c), y*z*(1-c) - x*s], - [z*x*(1-c) - y*s, z*y*(1-c) + x*s, c + z*z*(1-c)], - ]) + return np.array( + [ + [ + c + x * x * (1 - c), + x * y * (1 - c) - z * s, + x * z * (1 - c) + y * s, + ], + [ + y * x * (1 - c) + z * s, + c + y * y * (1 - c), + y * z * (1 - c) - x * s, + ], + [ + z * x * (1 - c) - y * s, + z * y * (1 - c) + x * s, + c + z * z * (1 - c), + ], + ] + ) def __call__(self, curve): """ @@ -154,6 +173,7 @@ class RandomMirror3D: """ Randomly mirrors a 3D curve along one or more axes about its centroid. """ + def __init__(self, axes=(0, 1, 2), p=0.5): """ Initializes a RandomMirror3D transformer. diff --git a/src/neuron_proofreader/geometric_learning/curve_datamodules.py b/src/neuron_proofreader/geometric_learning/curve_datamodules.py new file mode 100644 index 0000000..68aa29d --- /dev/null +++ b/src/neuron_proofreader/geometric_learning/curve_datamodules.py @@ -0,0 +1,267 @@ +""" +Created on Mon June 8 17:00:00 2026 + +@author: Anna Grim +@email: anna.grim@alleninstitute.org + +... + +""" + +from torch.utils.data import Dataset, DataLoader, Sampler + +import networkx as nx +import numpy as np +import pandas as pd +import torch + +from neuron_proofreader.skeleton_graph import SkeletonGraph +from neuron_proofreader.utils import util + + +# --- Dataset Classes --- +class PathsDataset(Dataset): + + def __init__( + self, + brain_id, + swcs_path, + graph_config=None, + max_length=np.inf, + transform=None, + ): + # Instance attributes + self.brain_id = brain_id + self.max_length = max_length + self.transform = transform + + # Core data structures + self.graph = self.load_skeletons(graph_config, swcs_path) + self.paths = self.irreducible_paths() + + + def load_skeletons(self, config, swcs_path): + graph = SkeletonGraph( + anisotropy=config.anisotropy, + min_cable_length=config.min_cable_length, + node_spacing=config.node_spacing, + use_anisotropy=config.use_anisotropy, + verbose=config.verbose, + ) + graph.load(swcs_path) + return graph + + # --- Get Examples --- + def __getitem__(self, i): + # Get path + path = self.paths[i].copy() + + # Check whether to subsample + if self.path_length(path) > self.max_length: + root = util.sample_once(path) + new_length = np.random.random() * self.max_length + path = self.path_thru_node(root, max_depth=new_length) + + # Check whether to transform + curve = self.node_xyz[path] + if self.transform: + curve = self.transform(curve) + + # Normalize + curve -= curve[0] + curve[1:] -= curve[:-1] + return curve + + # --- Helpers --- + def path_lengths(self): + return np.array([self.path_length(p) for p in self.paths]) + + def __getattr__(self, name): + return getattr(self.graph, name) + + def __len__(self): + return len(self.paths) + + def __repr__(self): + lengths = self.path_lengths() + num_neurons = nx.number_connected_components(self.graph) + return ( + f"BrainDataset(" + f"\n brain_id={self.brain_id}, " + f"\n num_neurons={num_neurons}, " + f"\n num_paths={len(self)}, " + f"\n min_path_length={np.min(lengths):.2f}, " + f"\n mean_path_length={np.mean(lengths):.2f}, " + f"\n max_length={np.max(lengths):.2f}," + f"\n)" + ) + + +class PathsDatasetCollection(Dataset): + + def __init__(self, datasets, is_val=False, n_val_examples=1000, seed=42): + """ + Parameters + ---------- + datasets : List[PathsDataset] + List of PathsDataset instances, one per brain. + is_val : bool, optional + If True, precomputes a fixed set of examples at construction time. + Default is False. + n_val_examples : int, optional + Number of fixed validation examples to precompute. Default is 1000. + seed : int, optional + Random seed for reproducible val set. Default is 42. + """ + # Instance attributes + self.datasets = datasets + self.is_val = is_val + self.set_examples_df() + + # Check whether to set validation examples + if is_val: + self.val_examples = self.set_val_examples(n_val_examples, seed) + + def set_examples_df(self): + rows = [] + for ds_idx, dataset in enumerate(self.datasets): + ds_idxs = np.full(len(dataset), ds_idx) + p_idxs = np.arange(len(dataset)) + ds_lengths = dataset.path_lengths() + rows.append( + pd.DataFrame( + { + "ds_idx": ds_idxs, + "path_idx": p_idxs, + "length": ds_lengths, + } + ) + ) + self.examples_df = pd.concat(rows, ignore_index=True) + + def set_val_examples(self, n, seed): + """ + Samples n examples with fixed seed, strips transforms, and caches + the resulting examples. + """ + rng = np.random.default_rng(seed) + indices = rng.choice(len(self.examples_df), size=n, replace=False) + examples = [] + for i in indices: + ds_idx = self.examples_df["ds_idx"][i] + path_idx = self.examples_df["path_idx"][i] + dataset = self.datasets[ds_idx] + examples.append(dataset[path_idx]) + return examples + + # --- Data Fetching --- + def __getitem__(self, i): + # Case 1: validation example + if self.is_val: + return self.val_examples[i] + + # Case 2: train example + ds_idx = self.examples_df["ds_idx"][i] + path_idx = self.examples_df["path_idx"][i] + return self.datasets[ds_idx][path_idx] + + def __len__(self): + if self.is_val: + return len(self.val_examples) + return len(self.examples_df) + + def __repr__(self): + return ( + f"PathsDatasetCollection(" + f"num_brains={len(self.datasets)}, " + f"num_paths={len(self.examples_df)}) " + ) + + +# --- DataLoader Classes --- +class PathSampler(Sampler): + + def __init__(self, dataset, examples_per_epoch): + """ + Parameters + ---------- + dataset : PathsDatasetCollection + Dataset to sample from. + """ + self.dataset = dataset + self.examples_per_epoch = examples_per_epoch + + def __iter__(self): + idxs = self.dataset.examples_df.sample( + self.examples_per_epoch, weights="length" + ).index + return iter(np.array(idxs)) + + def __len__(self): + return self.examples_per_epoch + + +def collate_curves(curves): + """ + Pads a list of curves to the longest in the batch and generates an + attention mask. + + Parameters + ---------- + curves : List[numpy.ndarray] + Each of shape (N_i, 3), where N_i can vary. + + Returns + ------- + padded : torch.Tensor + Shape (B, N_max, 3), zero-padded. + mask : torch.Tensor + Shape (B, N_max), True where padding. + """ + lengths = [len(c) for c in curves] + n_max = max(lengths) + B = len(curves) + + padded = torch.zeros(B, n_max, 3) + mask = torch.ones(B, n_max, dtype=torch.bool) + for i, (c, l) in enumerate(zip(curves, lengths)): + padded[i, :l] = torch.tensor(c) + mask[i, :l] = False + return padded, mask + + +def build_dataloader( + dataset, + batch_size=32, + examples_per_epoch=5000, + num_workers=0, + use_sampler=True, +): + """ + Builds a DataLoader for a PathsDatasetCollection that samples uniformly + across all non-empty bins. + + Parameters + ---------- + dataset : PathsDatasetCollection + Dataset to load from. + examples_per_epoch : int, optional + Number of examples per epoch. Default is 5000. + batch_size : int, optional + Number of curves per batch. Default is 32. + num_workers : int, optional + Number of worker processes for data loading. Default is 0. + + Returns + ------- + DataLoader + """ + sampler = PathSampler(dataset, examples_per_epoch) if use_sampler else None + return DataLoader( + dataset, + batch_size=batch_size, + collate_fn=collate_curves, + num_workers=num_workers, + pin_memory=True, + sampler=sampler, + ) diff --git a/src/neuron_proofreader/geometric_learning/path_transformer.py b/src/neuron_proofreader/geometric_learning/curve_transformer.py similarity index 75% rename from src/neuron_proofreader/geometric_learning/path_transformer.py rename to src/neuron_proofreader/geometric_learning/curve_transformer.py index fba9326..45cc0da 100644 --- a/src/neuron_proofreader/geometric_learning/path_transformer.py +++ b/src/neuron_proofreader/geometric_learning/curve_transformer.py @@ -12,6 +12,8 @@ import torch import torch.nn as nn +from neuron_proofreader.utils import util + class CurveEncoder(nn.Module): """ @@ -56,7 +58,7 @@ def __init__( # Instance attributes self.segment_len = segment_len - self.start_token = nn.Linear(3, d_token) + self.start_token = nn.Parameter(torch.randn(1, 1, d_token)) self.end_token_proj = nn.Linear(3, d_token) self.segment_proj = nn.Linear(segment_len * 3, d_token) @@ -76,13 +78,13 @@ def __init__( nn.Linear(d_token, latent_dim), ) - def forward(self, curve, token_mask=None): + def forward(self, offsets, mask=None): """ Parameters ---------- - curve : torch.Tensor + offsets : torch.Tensor Shape (B, N, 3), normalized to the unit sphere, with - curve[:, 0] == [0, 0, 0]. N can vary across calls. + offsets[:, 0] == [0, 0, 0]. N can vary across calls. mask : torch.Tensor, optional Shape (B, N), True where padding (to be ignored). Default is None. @@ -93,36 +95,47 @@ def forward(self, curve, token_mask=None): tokens : torch.Tensor Per-token encodings of shape (B, n_segments + 2, d_token). """ - B, N, _ = curve.shape + B, N, _ = offsets.shape n_segments = N // self.segment_len # Start and end tokens start_tok = self.start_token.expand(B, -1, -1) # (B, 1, d_token) - end_tok = self.end_token_proj(curve[:, -1, :]).unsqueeze(1) # (B, 1, d_token) + end_tok = self.end_token_proj(offsets[:, -1, :]).unsqueeze( + 1 + ) # (B, 1, d_token) # Segment tokens - segments = curve[:, :n_segments * self.segment_len, :] + segments = offsets[:, : n_segments * self.segment_len, :] segments = segments.reshape(B, n_segments, self.segment_len * 3) seg_tokens = self.segment_proj(segments) # (B, n_seg, d_token) # Concatenate: [start | segments | end] - tokens = torch.cat([start_tok, seg_tokens, end_tok], dim=1) # (B, n_seg+2, d_token) - - # Sinusoidal positional encoding over arc position - pe = sinusoidal_encoding(tokens.shape[1], tokens.shape[2], tokens.device) - if token_mask is not None: - pe = pe * (~token_mask).unsqueeze(-1).float() # Zero out pe for padding - tokens = tokens + pe + tokens = torch.cat( + [start_tok, seg_tokens, end_tok], dim=1 + ) # (B, n_seg+2, d_token) # Convert point-level mask to token-level mask token_mask = None + if mask is not None: + seg_mask = mask[:, :: self.segment_len][ + :, :n_segments + ] # (B, n_seg) + token_mask = torch.cat( + [ + torch.zeros(B, 1, dtype=torch.bool, device=mask.device), + seg_mask, + torch.zeros(B, 1, dtype=torch.bool, device=mask.device), + ], + dim=1, + ) # (B, n_seg+2) + + # Sinusoidal positional encoding, zeroed out for padding tokens + pe = sinusoidal_encoding( + tokens.shape[1], tokens.shape[2], tokens.device + ) if token_mask is not None: - seg_mask = token_mask[:, ::self.segment_len][:, :n_segments] # (B, n_seg) - token_mask = torch.cat([ - torch.zeros(B, 1, dtype=torch.bool, device=token_mask.device), - seg_mask, - torch.zeros(B, 1, dtype=torch.bool, device=token_mask.device), - ], dim=1) # (B, n_seg+2) + pe = pe * (~token_mask).unsqueeze(-1).float() + tokens = tokens + pe tokens = self.transformer(tokens, src_key_padding_mask=token_mask) @@ -132,6 +145,7 @@ def forward(self, curve, token_mask=None): z = self.to_latent((tokens * valid).sum(dim=1) / valid.sum(dim=1)) else: z = self.to_latent(tokens.mean(dim=1)) + return z, tokens @@ -143,15 +157,16 @@ class CurveDecoder(nn.Module): The output resolution can differ from the encoder input, allowing decoding at arbitrary granularity. """ + def __init__( self, n_points=100, segment_len=10, - d_token=128, + d_token=64, n_heads=4, n_layers=4, - d_ff=512, - latent_dim=64, + d_ff=64, + latent_dim=32, dropout=0.1, ): """ @@ -188,7 +203,9 @@ def __init__( dropout=dropout, batch_first=True, ) - self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=n_layers) + self.transformer = nn.TransformerDecoder( + decoder_layer, num_layers=n_layers + ) self.to_points = nn.Sequential( nn.LayerNorm(d_token), @@ -231,31 +248,39 @@ def forward(self, z, encoder_tokens, encoder_mask=None, n_segments=None): ) # (B, n_seg, d_token) segments = self.to_points(out) # (B, n_seg, seg_len*3) - curve = segments.reshape(B, n_segments * self.segment_len, 3) - return curve + offsets = segments.reshape(B, n_segments * self.segment_len, 3) + return offsets class CurveAutoencoder(nn.Module): - """ - ARBOR: Autoencoder for Representing Branching and Ordered curves in 3D. - Encodes a uniformly sampled, ordered 3D curve normalized to the unit - sphere to a latent vector, and decodes it back to a curve of the same - length. Robust to varying input lengths via dynamic sinusoidal positional - encoding over the normalized arc position. - """ def __init__( self, n_points=100, segment_len=10, - d_token=128, + d_token=64, n_heads=4, n_layers=4, - d_ff=512, - latent_dim=64, + d_ff=64, + latent_dim=32, dropout=0.1, ): + # Call parent class super().__init__() + + # Config + self.config = { + "n_points": n_points, + "segment_len": segment_len, + "d_token": d_token, + "n_heads": n_heads, + "n_layers": n_layers, + "d_ff": d_ff, + "latent_dim": latent_dim, + "dropout": dropout, + } + + # Architecture self.encoder = CurveEncoder( segment_len=segment_len, d_token=d_token, @@ -276,12 +301,12 @@ def __init__( dropout=dropout, ) - def forward(self, curve): + def forward(self, offsets, token_mask): """ Parameters ---------- - curve : torch.Tensor - Shape (B, N, 3), normalized to the unit sphere, curve[:, 0] == 0. + offsets : torch.Tensor + Shape (B, N, 3), normalized to the unit sphere, offsets[:, 0] == 0. Returns ------- @@ -290,15 +315,26 @@ def forward(self, curve): z : torch.Tensor Latent vector of shape (B, latent_dim). """ - z, encoder_tokens = self.encoder(curve) - n_segments = (curve.shape[1] // self.decoder.segment_len) + z, encoder_tokens = self.encoder(offsets, token_mask) + n_segments = offsets.shape[1] // self.decoder.segment_len reconstruction = self.decoder(z, encoder_tokens, n_segments=n_segments) return reconstruction, z - def encode(self, curve): - z, _ = self.encoder(curve) + def encode(self, offsets): + z, _ = self.encoder(offsets) return z + # --- Helpers --- + def save_config(self, path): + util.write_json(path, self.config) + + @classmethod + def load(cls, path): + checkpoint = torch.load(path) + model = cls(**checkpoint["config"]) + model.load_state_dict(checkpoint["model_state_dict"]) + return model + # --- Helpers --- def sinusoidal_encoding(n_tokens, d_token, device): @@ -321,7 +357,8 @@ def sinusoidal_encoding(n_tokens, d_token, device): """ position = torch.linspace(0, 1, n_tokens, device=device).unsqueeze(1) div_term = torch.exp( - torch.arange(0, d_token, 2, device=device) * (-np.log(10000.0) / d_token) + torch.arange(0, d_token, 2, device=device) + * (-np.log(10000.0) / d_token) ) pe = torch.zeros(n_tokens, d_token, device=device) pe[:, 0::2] = torch.sin(position * div_term) diff --git a/src/neuron_proofreader/geometric_learning/skel_datamodules.py b/src/neuron_proofreader/geometric_learning/skel_datamodules.py deleted file mode 100644 index 2ab7319..0000000 --- a/src/neuron_proofreader/geometric_learning/skel_datamodules.py +++ /dev/null @@ -1,277 +0,0 @@ -""" -Created on Mon June 8 17:00:00 2026 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - -... - -""" - -from random import shuffle -from torch.utils.data import Dataset, DataLoader, Sampler - -import networkx as nx -import numpy as np -import torch - -from neuron_proofreader.skeleton_graph import SkeletonGraph -from neuron_proofreader.utils import util - - -# --- Dataset Classes --- -class PathsDataset(Dataset): - - def __init__( - self, - brain_id, - swcs_path, - bin_width=400, - graph_config=None, - max_path_length=10e4, - transform=None, - ): - # Instance attributes - self.brain_id = brain_id - self.bin_width = bin_width - self.max_path_length = max_path_length - - # Core data structures - self.graph = self.load_skeletons(graph_config, swcs_path) - self.paths = self.irreducible_paths() - self.transform = transform - - self.set_bins() - - def load_skeletons(self, config, swcs_path): - graph = SkeletonGraph( - anisotropy=config.anisotropy, - min_cable_length=config.min_cable_length, - node_spacing=config.node_spacing, - use_anisotropy=config.use_anisotropy, - verbose=config.verbose, - ) - graph.load(swcs_path) - return graph - - def set_bins(self): - # Initialize bins - db = self.bin_width - n_bins = int(np.ceil(self.max_path_length / db)) - self.bins = {(i * db, (i + 1) * db): [] for i in range(n_bins)} - - # Store path indices in bins - for idx, p in enumerate(self.paths): - i = min(int(self.path_length(p) / db), n_bins - 1) - self.bins[(i * db, (i + 1) * db)].append(idx) - - # --- Get Examples --- - def __getitem__(self, bin_id): - # Get path - idx = util.sample_once(self.bins[bin_id]) - path = self.paths[idx].copy() - - # Check whether to subsample - if self.path_length(path) > self.max_path_length: - new_length = np.random.uniform(*bin_id) - node = util.sample_once(path) - path = self.path_thru_node(node, max_depth=new_length) - - # Check whether to transform - curve = self.node_xyz[path] - if self.transform: - curve = self.transform(curve) - - # Normalize - curve = (curve - curve[0]) / self.max_path_length - return curve - - # --- Helpers --- - def __getattr__(self, name): - return getattr(self.graph, name) - - def __len__(self): - return len(self.paths) - - def __repr__(self): - lengths = [self.path_length(p) for p in self.paths] - num_neurons = nx.number_connected_components(self.graph) - return ( - f"BrainDataset(" - f"\n brain_id={self.brain_id}, " - f"\n num_neurons={num_neurons}, " - f"\n num_paths={len(self)}, " - f"\n min_path_length={np.min(lengths):.2f}, " - f"\n mean_path_length={np.mean(lengths):.2f}, " - f"\n max_path_length={np.max(lengths):.2f}," - f"\n)" - ) - - -class PathsDatasetCollection(Dataset): - - def __init__(self, datasets, examples_per_bin=32): - """ - Collection of PathsDataset instances, one per brain, with a unified - bin structure for sampling uniformly across brains and path lengths. - - Parameters - ---------- - datasets : List[PathsDataset] - List of PathsDataset instances, one per brain. - """ - self.datasets = datasets - self.examples_per_bin = examples_per_bin - self.set_bins() - - def set_bins(self): - """ - Builds a unified bin structure across all datasets. Each bin key is a - (lower, upper) tuple and maps to a list of (dataset_idx, path_idx) - pairs. - """ - # Collect all bin keys across datasets - all_keys = set() - for ds in self.datasets: - all_keys.update(ds.bins.keys()) - - self.bins = {k: [] for k in sorted(all_keys)} - for ds_idx, ds in enumerate(self.datasets): - for bin_id, path_indices in ds.bins.items(): - for path_idx in path_indices: - self.bins[bin_id].append((ds_idx, path_idx)) - - def __getitem__(self, bin_id): - """ - Samples a random (dataset, path) pair from the given bin. - - Parameters - ---------- - bin_id : Tuple[float, float] - The (lower, upper) bin key. - - Returns - ------- - numpy.ndarray - Normalized curve of shape (N, 3). - """ - ds_idx, path_idx = util.sample_once(self.bins[bin_id]) - return self.datasets[ds_idx][bin_id] - - # --- Helpers --- - def all_path_lengths(self): - path_lengths = list() - for dataset in self.datasets: - path_lengths.extend( - [dataset.path_length(p) for p in dataset.paths] - ) - return np.array(path_lengths) - - def generate_bin_ids(self): - nonempty_keys = [k for k, v in self.bins.items() if v] - bin_ids = list(nonempty_keys) * self.examples_per_bin - shuffle(bin_ids) - return bin_ids - - def __len__(self): - return sum(len(ds) for ds in self.datasets) - - def __repr__(self): - n_brains = len(self.datasets) - n_paths = len(self) - non_empty = sum(1 for v in self.bins.values() if len(v) > 0) - return ( - f"PathsDatasetCollection(" - f"num_brains={n_brains}, " - f"num_paths={n_paths}, " - f"num_bins={non_empty})" - ) - - -# --- DataLoader Classes --- -class BinSampler(Sampler): - """ - Sampler that yields bin IDs uniformly across all non-empty bins, - with a fixed number of examples per bin per epoch. - """ - def __init__(self, dataset, examples_per_bin=10): - """ - Parameters - ---------- - dataset : PathsDatasetCollection - Dataset to sample from. - examples_per_bin : int, optional - Number of examples to draw from each bin per epoch. Default is 10. - """ - self.examples_per_bin = examples_per_bin - self.non_empty_bins = [k for k, v in dataset.bins.items() if len(v) > 0] - - def __iter__(self): - bin_ids = self.non_empty_bins * self.examples_per_bin - shuffle(bin_ids) - return iter(bin_ids) - - def __len__(self): - return len(self.non_empty_bins) * self.examples_per_bin - - -def collate_curves(curves): - """ - Pads a list of curves to the longest in the batch and generates an - attention mask. - - Parameters - ---------- - curves : List[numpy.ndarray] - Each of shape (N_i, 3), where N_i can vary. - - Returns - ------- - padded : torch.Tensor - Shape (B, N_max, 3), zero-padded. - mask : torch.Tensor - Shape (B, N_max), True where padding. - """ - lengths = [len(c) for c in curves] - n_max = max(lengths) - B = len(curves) - - padded = torch.zeros(B, n_max, 3) - mask = torch.ones(B, n_max, dtype=torch.bool) - for i, (c, l) in enumerate(zip(curves, lengths)): - padded[i, :l] = torch.tensor(c) - mask[i, :l] = False - return padded, mask - - -def build_dataloader( - dataset, batch_size=32, examples_per_bin=32, num_workers=0 -): - """ - Builds a DataLoader for a PathsDatasetCollection that samples uniformly - across all non-empty bins. - - Parameters - ---------- - dataset : PathsDatasetCollection - Dataset to load from. - examples_per_bin : int, optional - Number of examples per bin per epoch. Default is 10. - batch_size : int, optional - Number of curves per batch. Default is 32. - num_workers : int, optional - Number of worker processes for data loading. Default is 4. - - Returns - ------- - DataLoader - """ - sampler = BinSampler(dataset, examples_per_bin=examples_per_bin) - return DataLoader( - dataset, - batch_size=batch_size, - collate_fn=collate_curves, - num_workers=num_workers, - pin_memory=True, - sampler=sampler, - ) diff --git a/src/neuron_proofreader/machine_learning/image_dataloader.py b/src/neuron_proofreader/machine_learning/image_dataloader.py index bd9c8c5..554be34 100644 --- a/src/neuron_proofreader/machine_learning/image_dataloader.py +++ b/src/neuron_proofreader/machine_learning/image_dataloader.py @@ -234,7 +234,7 @@ def create_mask(self, center, shape, node): # Annotate mask mask = np.zeros(shape) - #self.annotate_foreground(mask, nodes, offset, fill_val=0.5) TEMP + # self.annotate_foreground(mask, nodes, offset, fill_val=0.5) TEMP self.annotate_fragment(mask, subgraph, offset, fill_val=1) return mask diff --git a/src/neuron_proofreader/machine_learning/train.py b/src/neuron_proofreader/machine_learning/train.py index e0cca48..9c65b4a 100644 --- a/src/neuron_proofreader/machine_learning/train.py +++ b/src/neuron_proofreader/machine_learning/train.py @@ -66,6 +66,7 @@ def __init__( model_name, output_dir, device="cuda", + exp_name=None, lr=1e-3, max_epochs=200, min_recall=0, @@ -94,7 +95,8 @@ def __init__( Indication of whether to save MIPs of mistakes. Default is False. """ # Set experiment name - exp_name = "session-" + datetime.today().strftime("%Y%m%d_%H%M") + if exp_name is None: + exp_name = "session-" + datetime.today().strftime("%Y%m%d_%H%M") log_dir = os.path.join(output_dir, exp_name) util.mkdir(log_dir) diff --git a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py index d9b336a..3e97fd7 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py @@ -132,7 +132,7 @@ def set_merge_site_info(self): # Store fragment IDs corresponding to merge sites for xyz in self.merge_sites["xyz"]: _, ii = self.kdtree.query(xyz) - #self.ignore_fragments.add(self.node_component_id[ii]) + # self.ignore_fragments.add(self.node_component_id[ii]) def set_giant_components(self): for nodes in map(list, nx.connected_components(self.graph)): @@ -176,7 +176,9 @@ def get_random_nonmerge_site(self): # Try again n_attempts += 1 if n_attempts > 100: - print(f"Failed to find valid random nonmerge site for {self.brain_id}!") + print( + f"Failed to find valid random nonmerge site for {self.brain_id}!" + ) return util.sample_once(self.nodes) # --- Helpers --- @@ -324,21 +326,25 @@ def save_val_summary(self, output_dir): # Merge sites for _, row in brain_dataset.merge_sites.iterrows(): - rows.append({ - "brain_id": brain_id, - "swc_name": row["filename"], - "xyz": row["xyz"], - "label": "merge", - }) + rows.append( + { + "brain_id": brain_id, + "swc_name": row["filename"], + "xyz": row["xyz"], + "label": "merge", + } + ) # Nonmerge sites for _, row in brain_dataset.nonmerge_sites.iterrows(): - rows.append({ - "brain_id": brain_id, - "swc_name": row["filename"], - "xyz": row["xyz"], - "label": "nonmerge", - }) + rows.append( + { + "brain_id": brain_id, + "swc_name": row["filename"], + "xyz": row["xyz"], + "label": "nonmerge", + } + ) df = pd.DataFrame(rows) df.to_csv(os.path.join(output_dir, "val_summary.csv"), index=False) @@ -439,7 +445,7 @@ def __iter__(self): # Split into batches upfront batch_idx_groups = [ - idxs[start: min(start + self.batch_size, len(idxs))] + idxs[start : min(start + self.batch_size, len(idxs))] for start in range(0, len(idxs), self.batch_size) ] @@ -631,10 +637,12 @@ def create_dataset_collection( img_path = os.path.join(img_prefixes[brain_id], "0") segmentation_id = get_segmentation_id(sites_root_path, brain_id) sites_path = os.path.join(sites_root_path, brain_id, segmentation_id) - swcs_path = os.path.join(swcs_root_path, brain_id, segmentation_id, "fragments") - #util.get_google_swcs_prefix( + swcs_path = os.path.join( + swcs_root_path, brain_id, segmentation_id, "fragments" + ) + # util.get_google_swcs_prefix( # swcs_root_path, brain_id, segmentation_id - #) + # ) # Add dataset print(f" \nBrain ID [{i}/{len(brain_ids)}]: {brain_id}") diff --git a/src/neuron_proofreader/skeleton_graph.py b/src/neuron_proofreader/skeleton_graph.py index cc4833e..603b985 100644 --- a/src/neuron_proofreader/skeleton_graph.py +++ b/src/neuron_proofreader/skeleton_graph.py @@ -1090,9 +1090,7 @@ def path_length(self, path): """ if len(path) > 1: diffs = self.node_xyz[path[1:]] - self.node_xyz[path[:-1]] - return np.sum(np.sqrt(np.sum(diffs**2, axis=1))) - else: - return 0 + return np.linalg.norm(diffs**2, axis=1).sum() def path_thru_node(self, i, max_depth=np.inf): if self.degree[i] == 0: diff --git a/src/neuron_proofreader/utils/graph_util.py b/src/neuron_proofreader/utils/graph_util.py index dad3464..6a4ca7b 100644 --- a/src/neuron_proofreader/utils/graph_util.py +++ b/src/neuron_proofreader/utils/graph_util.py @@ -120,7 +120,9 @@ def __call__(self, swc_pointer): # Continue submitting processes if swc_dicts: - pending.add(executor.submit(self.load, swc_dicts.pop())) + pending.add( + executor.submit(self.load, swc_dicts.pop()) + ) return irreducibles def load(self, swc_dict): diff --git a/src/neuron_proofreader/utils/swc_util.py b/src/neuron_proofreader/utils/swc_util.py index 00e828f..dc1e581 100644 --- a/src/neuron_proofreader/utils/swc_util.py +++ b/src/neuron_proofreader/utils/swc_util.py @@ -169,7 +169,9 @@ def read_swcs(self, swc_paths): """ with ThreadPoolExecutor() as executor: # Assign threads - threads = {executor.submit(self.read_swc, p) for p in swc_paths} + threads = { + executor.submit(self.read_swc, p) for p in swc_paths[0:1] + } pbar = self.manual_progress_bar(len(threads)) # Store results diff --git a/src/neuron_proofreader/utils/util.py b/src/neuron_proofreader/utils/util.py index 4056cbf..df62329 100644 --- a/src/neuron_proofreader/utils/util.py +++ b/src/neuron_proofreader/utils/util.py @@ -401,7 +401,7 @@ def parse_cloud_path(path): """ # Remove s3:// or gs:// if present if path.startswith("s3://") or path.startswith("gs://"): - path = path[len("s3://"):] + path = path[len("s3://") :] # Split path parts = path.split("/", 1) From 1e1a3d4a7cfeea587fadcddf4c3044d92995cfcf Mon Sep 17 00:00:00 2001 From: anna-grim Date: Sat, 13 Jun 2026 02:15:49 +0000 Subject: [PATCH 04/10] refactor: removed test block --- src/neuron_proofreader/utils/swc_util.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/neuron_proofreader/utils/swc_util.py b/src/neuron_proofreader/utils/swc_util.py index dc1e581..00e828f 100644 --- a/src/neuron_proofreader/utils/swc_util.py +++ b/src/neuron_proofreader/utils/swc_util.py @@ -169,9 +169,7 @@ def read_swcs(self, swc_paths): """ with ThreadPoolExecutor() as executor: # Assign threads - threads = { - executor.submit(self.read_swc, p) for p in swc_paths[0:1] - } + threads = {executor.submit(self.read_swc, p) for p in swc_paths} pbar = self.manual_progress_bar(len(threads)) # Store results From 64b89bcb21c17bbbc29280eab94c8500bbfa2be8 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Sat, 13 Jun 2026 03:12:57 +0000 Subject: [PATCH 05/10] refactor: path sample with replacement --- .../geometric_learning/curve_datamodules.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/neuron_proofreader/geometric_learning/curve_datamodules.py b/src/neuron_proofreader/geometric_learning/curve_datamodules.py index 68aa29d..b6addb9 100644 --- a/src/neuron_proofreader/geometric_learning/curve_datamodules.py +++ b/src/neuron_proofreader/geometric_learning/curve_datamodules.py @@ -90,8 +90,8 @@ def __repr__(self): f"\n brain_id={self.brain_id}, " f"\n num_neurons={num_neurons}, " f"\n num_paths={len(self)}, " - f"\n min_path_length={np.min(lengths):.2f}, " - f"\n mean_path_length={np.mean(lengths):.2f}, " + f"\n min_length={np.min(lengths):.2f}, " + f"\n mean_length={np.mean(lengths):.2f}, " f"\n max_length={np.max(lengths):.2f}," f"\n)" ) @@ -193,7 +193,7 @@ def __init__(self, dataset, examples_per_epoch): def __iter__(self): idxs = self.dataset.examples_df.sample( - self.examples_per_epoch, weights="length" + self.examples_per_epoch, replace=True, weights="length" ).index return iter(np.array(idxs)) From ed4bc3de3d3b1bb7fb46e432cf77acef04871981 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Sat, 13 Jun 2026 19:48:25 +0000 Subject: [PATCH 06/10] refactor: max path length --- .../geometric_learning/curve_datamodules.py | 22 ++++---- .../split_proofreading/split_inference.py | 54 +++++++++---------- 2 files changed, 37 insertions(+), 39 deletions(-) diff --git a/src/neuron_proofreader/geometric_learning/curve_datamodules.py b/src/neuron_proofreader/geometric_learning/curve_datamodules.py index b6addb9..c2fbf51 100644 --- a/src/neuron_proofreader/geometric_learning/curve_datamodules.py +++ b/src/neuron_proofreader/geometric_learning/curve_datamodules.py @@ -8,6 +8,7 @@ """ +from copy import deepcopy from torch.utils.data import Dataset, DataLoader, Sampler import networkx as nx @@ -37,8 +38,7 @@ def __init__( # Core data structures self.graph = self.load_skeletons(graph_config, swcs_path) - self.paths = self.irreducible_paths() - + self.paths = self.get_valid_paths() def load_skeletons(self, config, swcs_path): graph = SkeletonGraph( @@ -51,19 +51,17 @@ def load_skeletons(self, config, swcs_path): graph.load(swcs_path) return graph + def get_valid_paths(self): + paths = list() + for p in self.irreducible_paths(): + if self.path_length(p) < max_length: + paths.append(p) + return paths + # --- Get Examples --- def __getitem__(self, i): # Get path - path = self.paths[i].copy() - - # Check whether to subsample - if self.path_length(path) > self.max_length: - root = util.sample_once(path) - new_length = np.random.random() * self.max_length - path = self.path_thru_node(root, max_depth=new_length) - - # Check whether to transform - curve = self.node_xyz[path] + curve = deepcopy(self.node_xyz[self.paths[i]]) if self.transform: curve = self.transform(curve) diff --git a/src/neuron_proofreader/split_proofreading/split_inference.py b/src/neuron_proofreader/split_proofreading/split_inference.py index ed84f02..9bdd9a6 100644 --- a/src/neuron_proofreader/split_proofreading/split_inference.py +++ b/src/neuron_proofreader/split_proofreading/split_inference.py @@ -7,24 +7,19 @@ Code that executes the full split correction pipeline. Inference Pipeline: - 1. Graph Construction - Build graph from neuron fragments. - - 2. Proposal Generation + 1. Proposal Generation Generate proposals for potential connections between fragments. - 3. Proposal Classification + 2. Proposal Classification a. Feature Generation Extract features from proposals and graph for a machine learning model. b. Predict with Graph Neural Network (GNN) Run a GNN to classify proposals as accept/reject based on the learned features. - c. Merge Accepted Proposals - Add accepted proposals to the graph as edges. -Note: Steps 2 and 3 of the inference pipeline can be iterated in a loop that - repeats multiple times by calling the pipeline in a loop + 3. Merge Accepted Proposals + Add accepted proposals to the graph as edges. """ @@ -58,7 +53,9 @@ def __init__( img_path, output_dir, model, - config, + graph_config, + ml_config, + proposals_config, log_preamble="", soma_centroids=list(), ): @@ -74,9 +71,7 @@ def __init__( Path to whole-brain image corresponding to the given fragments. output_dir : str Directory where the results of the inference will be saved. - config : Config - Configuration object containing parameters and settings required - for the inference pipeline. + ... log_preamble : str, optional String to be added to the beginning of log. Default is an empty string. @@ -93,7 +88,7 @@ def __init__( # Logger util.mkdir(self.output_dir) - log_path = os.path.join(self.output_dir, "runtimes.txt") + log_path = os.path.join(self.output_dir, "summary.txt") self.log_handle = open(log_path, "a") self.log(log_preamble) @@ -406,19 +401,6 @@ def predict(self, data): hat_y = ml_util.tensor_to_list(hat_y) return {idx_to_id[i]: y_i for i, y_i in enumerate(hat_y)} - def save_graph(self, dirname): - # Set paths - temp_dir = os.path.join(self.output_dir, "temp") - output_zip_path = os.path.join(self.output_dir, dirname, "swcs.zip") - util.mkdir(temp_dir) - util.mkdir(os.path.join(self.output_dir, dirname)) - - # Save swcs - self.dataset.to_zipped_swcs_multithreaded(temp_dir) - zip_paths = util.list_paths(temp_dir, extension=".zip") - util.combine_zips(zip_paths, output_zip_path) - util.rmdir(temp_dir) - def save_proposal_results(self, preds_dict, suffix=""): summary = list() for proposal, pred in preds_dict.items(): @@ -468,3 +450,21 @@ def save_fragment_ids(self): path = f"{self.output_dir}/segment_ids.txt" segment_ids = list(self.dataset.component_id_to_swc_id.values()) util.write_list(path, segment_ids) + + +# --- Helpers --- +def create_dataset(): + pass + + +def save_graph(self, dataset, output_dir, dirname): + # Save graph across set of ZIPs + temp_dir = os.path.join(output_dir, "temp") + dataset.to_zipped_swcs_multithreaded(temp_dir) + + # Combine ZIPs into single ZIP + zip_paths = util.list_paths(temp_dir, extension=".zip") + final_zip_path = os.path.join(output_dir, dirname, "swcs.zip") + util.mkdir(os.path.join(output_dir, dirname)) + util.combine_zips(zip_paths, final_zip_path) + util.rmdir(temp_dir) From 48d33ce76b32aede0ff70ac65c472bf7aac58840 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Sat, 13 Jun 2026 20:29:24 +0000 Subject: [PATCH 07/10] bug: max path length implementation --- src/neuron_proofreader/configs.py | 2 +- .../geometric_learning/curve_datamodules.py | 2 +- src/neuron_proofreader/inference_pipeline.py | 115 ++++++++++++++++++ src/neuron_proofreader/proposal_graph.py | 2 +- src/neuron_proofreader/skeleton_graph.py | 37 ++---- .../split_proofreading/split_inference.py | 18 --- 6 files changed, 127 insertions(+), 49 deletions(-) create mode 100644 src/neuron_proofreader/inference_pipeline.py diff --git a/src/neuron_proofreader/configs.py b/src/neuron_proofreader/configs.py index 90be82e..016f8fd 100644 --- a/src/neuron_proofreader/configs.py +++ b/src/neuron_proofreader/configs.py @@ -1,5 +1,5 @@ """ -Created on Frid Sept 15 16:00:00 2024 +Created on Fri Sept 15 16:00:00 2024 @author: Anna Grim @email: anna.grim@alleninstitute.org diff --git a/src/neuron_proofreader/geometric_learning/curve_datamodules.py b/src/neuron_proofreader/geometric_learning/curve_datamodules.py index c2fbf51..475818b 100644 --- a/src/neuron_proofreader/geometric_learning/curve_datamodules.py +++ b/src/neuron_proofreader/geometric_learning/curve_datamodules.py @@ -54,7 +54,7 @@ def load_skeletons(self, config, swcs_path): def get_valid_paths(self): paths = list() for p in self.irreducible_paths(): - if self.path_length(p) < max_length: + if self.path_length(p) < self.max_length: paths.append(p) return paths diff --git a/src/neuron_proofreader/inference_pipeline.py b/src/neuron_proofreader/inference_pipeline.py new file mode 100644 index 0000000..402c1e4 --- /dev/null +++ b/src/neuron_proofreader/inference_pipeline.py @@ -0,0 +1,115 @@ +""" +Created on Fri June 13 16:00:00 2026 + +@author: Anna Grim +@email: anna.grim@alleninstitute.org + +Code for running full neuron proofreading pipeline, including both split and +merge detection and correction. + +""" + +from neuron_proofreader.skeleton_graph import SkeletonGraph + + +class ProofreadPipeline: + + def __init__( + self, + fragments_path, + graph_config, + img_path, + output_dir, + log_preamble="", + soma_centroids=list(), + ): + """ + Initializes an object that executes the full split proofreading + pipeline. + + Parameters + ---------- + fragments_path : str + Path to SWC files to be loaded into graph. + graph_config : GraphConfig + Configuration object that contains parameters for building graph. + img_path : str + Path to whole-brain image corresponding to the given fragments. + output_dir : str + Directory where the results of the inference will be saved. + log_preamble : str, optional + String to be added to the beginning of log. Default is an empty + string. + soma_centroids : List[Tuple[float]], optional + Physical coordinates of soma centroids. Default is an empty list. + """ + # Instance attributes + self.img_path = img_path + self.output_dir = output_dir + self.soma_centroids = soma_centroids + + # Logger + util.mkdir(self.output_dir) + log_path = os.path.join(self.output_dir, "summary.txt") + self.log_handle = open(log_path, "a") + self.log(log_preamble) + + # Load data + self._load_data(fragments_path, img_path) + + def load_graph(self, fragments_path, config): + """ + Loads a graph from the given fragments. + + Parameters + ---------- + fragments_path : str + Path to SWC files to be loaded into graph. + config : GraphConfig + Configuration object that contains parameters for building graph. + """ + # Load data + t0 = time() + self.log("Step 1: Build Graph") + self.graph = ProposalGraph( + anisotropy=config.anisotropy, + min_cable_length=config.min_cable_length, + node_spacing=config.node_spacing, + verbose=config.verbose, + ) + + # Save original graph state + self.save_fragment_ids() + self.save_graph("original_swcs") + self.log("Initial Graph") + self.log(self.graph) + + # Report runtime + elapsed, unit = util.time_writer(time() - t0) + self.log(f"Module Runtime: {elapsed:.2f} {unit}\n") + + # --- Split Proofreading --- + def split_proofreading(self): + pass + + # --- Merge Proofreading --- + def merge_proofreading(self, mode): + if mode == "heuristic": + results = self.graph.remove_high_risk_merges() + self.log(results) + elif mode == "connected_somas": + results = self.graph.remove_soma_merges() + self.log(results) + + # --- Helpers --- + def save_graph(self, dirname): + # Save graph across set of ZIPs + temp_dir = os.path.join(self.output_dir, "temp") + self.graph.to_zipped_swcs_multithreaded(temp_dir) + + # Combine ZIPs into single ZIP + zip_paths = util.list_paths(temp_dir, extension=".zip") + final_zip_path = os.path.join(self.output_dir, dirname, "swcs.zip") + util.mkdir(os.path.join(self.output_dir, dirname)) + util.combine_zips(zip_paths, final_zip_path) + util.rmdir(temp_dir) diff --git a/src/neuron_proofreader/proposal_graph.py b/src/neuron_proofreader/proposal_graph.py index 7a9616e..238d53d 100644 --- a/src/neuron_proofreader/proposal_graph.py +++ b/src/neuron_proofreader/proposal_graph.py @@ -37,7 +37,6 @@ def __init__( self, anisotropy=(1.0, 1.0, 1.0), gt_path=None, - max_proposals_per_leaf=3, min_cable_length=0, node_spacing=1, prune_depth=20.0, @@ -135,6 +134,7 @@ def generate_proposals( self, search_radius, allow_nonleaf_proposals=False, + max_proposals_per_leaf=3, min_size_with_proposals=0, ): """ diff --git a/src/neuron_proofreader/skeleton_graph.py b/src/neuron_proofreader/skeleton_graph.py index 603b985..697638b 100644 --- a/src/neuron_proofreader/skeleton_graph.py +++ b/src/neuron_proofreader/skeleton_graph.py @@ -1219,38 +1219,19 @@ def soma_nodes(self): soma_nodes.append(i) return soma_nodes - def summary(self, prefix=""): - """ - Generate a human-readable summary of the graph. - - Parameters - ---------- - prefix : str, optional - Optional string to prepend to the summary title. - - Returns - ------- - summary : str - Formatted multi-line string containing: - - Graph Name - - Number of connected components - - Number of nodes - - Number of edges - - Memory consumption (in GBs) - """ - # Compute values + def __repr__(self): n_components = format(nx.number_connected_components(self), ",") n_nodes = format(self.number_of_nodes(), ",") n_edges = format(self.number_of_edges(), ",") memory = util.get_memory_usage() - - # Compile results - summary = [f"{prefix} Graph"] - summary.append(f"# Connected Components: {n_components}") - summary.append(f"# Nodes: {n_nodes}") - summary.append(f"# Edges: {n_edges}") - summary.append(f"Memory Consumption: {memory:.2f} GBs") - return "\n".join(summary) + return ( + f"SkeletonGraph(" + f"\n n_components={n_components}," + f"\n n_nodes={n_nodes}," + f"\n n_edges={n_edges}," + f"\n memory={memory:.2f} GBs" + f"\n)" + ) def swc_ids(self): """ diff --git a/src/neuron_proofreader/split_proofreading/split_inference.py b/src/neuron_proofreader/split_proofreading/split_inference.py index 9bdd9a6..e9e9eed 100644 --- a/src/neuron_proofreader/split_proofreading/split_inference.py +++ b/src/neuron_proofreader/split_proofreading/split_inference.py @@ -450,21 +450,3 @@ def save_fragment_ids(self): path = f"{self.output_dir}/segment_ids.txt" segment_ids = list(self.dataset.component_id_to_swc_id.values()) util.write_list(path, segment_ids) - - -# --- Helpers --- -def create_dataset(): - pass - - -def save_graph(self, dataset, output_dir, dirname): - # Save graph across set of ZIPs - temp_dir = os.path.join(output_dir, "temp") - dataset.to_zipped_swcs_multithreaded(temp_dir) - - # Combine ZIPs into single ZIP - zip_paths = util.list_paths(temp_dir, extension=".zip") - final_zip_path = os.path.join(output_dir, dirname, "swcs.zip") - util.mkdir(os.path.join(output_dir, dirname)) - util.combine_zips(zip_paths, final_zip_path) - util.rmdir(temp_dir) From 95ac7f24861b72d5922b52e7f5d0c48f11c2a17b Mon Sep 17 00:00:00 2001 From: anna-grim Date: Sun, 14 Jun 2026 21:24:29 +0000 Subject: [PATCH 08/10] feat: automated proofreading pipeline --- src/neuron_proofreader/configs.py | 11 +- .../geometric_learning/curve_datamodules.py | 1 - src/neuron_proofreader/inference_pipeline.py | 115 -------- .../machine_learning/gnn_models.py | 2 +- .../proofreading_pipeline.py | 208 +++++++++++++++ src/neuron_proofreader/proposal_graph.py | 40 +-- src/neuron_proofreader/skeleton_graph.py | 24 +- .../split_proofreading/proposal_generation.py | 29 +-- .../split_proofreading/split_datasets.py | 76 ++---- .../split_feature_extraction.py | 8 +- .../split_proofreading/split_inference.py | 245 ++++++------------ src/neuron_proofreader/utils/img_util.py | 15 +- 12 files changed, 370 insertions(+), 404 deletions(-) delete mode 100644 src/neuron_proofreader/inference_pipeline.py create mode 100644 src/neuron_proofreader/proofreading_pipeline.py diff --git a/src/neuron_proofreader/configs.py b/src/neuron_proofreader/configs.py index 016f8fd..3a8e108 100644 --- a/src/neuron_proofreader/configs.py +++ b/src/neuron_proofreader/configs.py @@ -15,6 +15,7 @@ import os +from neuron_proofreader.machine_learning.augmentation import ImageTransforms from neuron_proofreader.utils import util @@ -97,13 +98,14 @@ class ImageConfig(Config): """ brightness_clip: int = 400 + img_path: str = None name: str = "image_config" percentiles: Tuple[float, float] = (1, 99.5) patch_shape: Tuple[int, int, int] = (128, 128, 128) - transform: bool = False + transform = None def set_train_mode(self): - self.transform = True + self.transform = ImageTransforms() @dataclass @@ -124,5 +126,8 @@ class ProposalsConfig(Config): """ allow_nonleaf_proposals: bool = False - proposals_per_leaf: int = 3 + max_proposals_per_leaf: int = 3 + min_size_with_proposals: float = 0 trim_endpoints_bool: bool = True + search_radius: float = 25 + search_scaling_factor: float = 1.5 diff --git a/src/neuron_proofreader/geometric_learning/curve_datamodules.py b/src/neuron_proofreader/geometric_learning/curve_datamodules.py index 475818b..e957f44 100644 --- a/src/neuron_proofreader/geometric_learning/curve_datamodules.py +++ b/src/neuron_proofreader/geometric_learning/curve_datamodules.py @@ -17,7 +17,6 @@ import torch from neuron_proofreader.skeleton_graph import SkeletonGraph -from neuron_proofreader.utils import util # --- Dataset Classes --- diff --git a/src/neuron_proofreader/inference_pipeline.py b/src/neuron_proofreader/inference_pipeline.py deleted file mode 100644 index 402c1e4..0000000 --- a/src/neuron_proofreader/inference_pipeline.py +++ /dev/null @@ -1,115 +0,0 @@ -""" -Created on Fri June 13 16:00:00 2026 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - -Code for running full neuron proofreading pipeline, including both split and -merge detection and correction. - -""" - -from neuron_proofreader.skeleton_graph import SkeletonGraph - - -class ProofreadPipeline: - - def __init__( - self, - fragments_path, - graph_config, - img_path, - output_dir, - log_preamble="", - soma_centroids=list(), - ): - """ - Initializes an object that executes the full split proofreading - pipeline. - - Parameters - ---------- - fragments_path : str - Path to SWC files to be loaded into graph. - graph_config : GraphConfig - Configuration object that contains parameters for building graph. - img_path : str - Path to whole-brain image corresponding to the given fragments. - output_dir : str - Directory where the results of the inference will be saved. - log_preamble : str, optional - String to be added to the beginning of log. Default is an empty - string. - soma_centroids : List[Tuple[float]], optional - Physical coordinates of soma centroids. Default is an empty list. - """ - # Instance attributes - self.img_path = img_path - self.output_dir = output_dir - self.soma_centroids = soma_centroids - - # Logger - util.mkdir(self.output_dir) - log_path = os.path.join(self.output_dir, "summary.txt") - self.log_handle = open(log_path, "a") - self.log(log_preamble) - - # Load data - self._load_data(fragments_path, img_path) - - def load_graph(self, fragments_path, config): - """ - Loads a graph from the given fragments. - - Parameters - ---------- - fragments_path : str - Path to SWC files to be loaded into graph. - config : GraphConfig - Configuration object that contains parameters for building graph. - """ - # Load data - t0 = time() - self.log("Step 1: Build Graph") - self.graph = ProposalGraph( - anisotropy=config.anisotropy, - min_cable_length=config.min_cable_length, - node_spacing=config.node_spacing, - verbose=config.verbose, - ) - - # Save original graph state - self.save_fragment_ids() - self.save_graph("original_swcs") - self.log("Initial Graph") - self.log(self.graph) - - # Report runtime - elapsed, unit = util.time_writer(time() - t0) - self.log(f"Module Runtime: {elapsed:.2f} {unit}\n") - - # --- Split Proofreading --- - def split_proofreading(self): - pass - - # --- Merge Proofreading --- - def merge_proofreading(self, mode): - if mode == "heuristic": - results = self.graph.remove_high_risk_merges() - self.log(results) - elif mode == "connected_somas": - results = self.graph.remove_soma_merges() - self.log(results) - - # --- Helpers --- - def save_graph(self, dirname): - # Save graph across set of ZIPs - temp_dir = os.path.join(self.output_dir, "temp") - self.graph.to_zipped_swcs_multithreaded(temp_dir) - - # Combine ZIPs into single ZIP - zip_paths = util.list_paths(temp_dir, extension=".zip") - final_zip_path = os.path.join(self.output_dir, dirname, "swcs.zip") - util.mkdir(os.path.join(self.output_dir, dirname)) - util.combine_zips(zip_paths, final_zip_path) - util.rmdir(temp_dir) diff --git a/src/neuron_proofreader/machine_learning/gnn_models.py b/src/neuron_proofreader/machine_learning/gnn_models.py index b5add00..95f1387 100644 --- a/src/neuron_proofreader/machine_learning/gnn_models.py +++ b/src/neuron_proofreader/machine_learning/gnn_models.py @@ -39,7 +39,7 @@ def __init__( self, patch_shape, disable_msg_passing=False, - heads=2, + heads=4, hidden_dim=128, n_layers=2, ): diff --git a/src/neuron_proofreader/proofreading_pipeline.py b/src/neuron_proofreader/proofreading_pipeline.py new file mode 100644 index 0000000..4528978 --- /dev/null +++ b/src/neuron_proofreader/proofreading_pipeline.py @@ -0,0 +1,208 @@ +""" +Created on Fri June 13 16:00:00 2026 + +@author: Anna Grim +@email: anna.grim@alleninstitute.org + +Code for running full neuron proofreading pipeline, including both split and +merge detection and correction. + +""" + +from time import time + +import numpy as np +import os + +from neuron_proofreader.proposal_graph import ProposalGraph +from neuron_proofreader.split_proofreading.split_inference import ( + SplitProofreader, +) +from neuron_proofreader.utils import geometry_util, swc_util, util + + +class ProofreadPipeline: + + def __init__( + self, + swcs_path, + graph_config, + img_config, + output_dir, + device="cuda", + log_preamble="", + soma_centroids=list(), + ): + """ + Initializes an object that executes the full split proofreading + pipeline. + + Parameters + ---------- + swcs_path : str + Path to SWC files to be loaded into graph. + graph_config : GraphConfig + Config object that contains parameters for building graph. + img_config : ImageConfig + Config object that contains parameters for processing images. + output_dir : str + Directory where the results of the inference will be saved. + log_preamble : str, optional + String to be added to the beginning of log. Default is an empty + string. + soma_centroids : List[Tuple[float]], optional + Physical coordinates of soma centroids. Default is an empty list. + """ + # Instance attributes + self.device = device + self.img_config = img_config + self.output_dir = output_dir + self.step_cnt = 0 + + # Logger + util.mkdir(self.output_dir) + log_path = os.path.join(self.output_dir, "summary.txt") + self.log_handle = open(log_path, "a") + self.log(log_preamble) + + # Load data + self.load_graph(graph_config, swcs_path, soma_centroids) + + def load_graph(self, config, swcs_path, soma_centroids): + """ + Loads a graph from the given fragments. + + Parameters + ---------- + swcs_path : str + Path to SWC files to be loaded into graph. + config : GraphConfig + Configuration object that contains parameters for building graph. + """ + # Load data + t0 = time() + self.step_cnt += 1 + self.log(f"Step {self.step_cnt}: Build Graph") + self.graph = ProposalGraph( + anisotropy=config.anisotropy, + min_cable_length=config.min_cable_length, + node_spacing=config.node_spacing, + verbose=config.verbose, + ) + self.graph.load(swcs_path) + self.graph.load_somas(soma_centroids) + + # Remove doubled fragments + if config.remove_doubles: + geometry_util.remove_doubles(self.graph, 200) + + # Save original graph state + self.save_graph("original_swcs") + self.log("\nInitial Graph...") + self.log(self.graph.__repr__()) + + # Report runtime + elapsed, unit = util.time_writer(time() - t0) + self.log(f"Module Runtime: {elapsed:.2f} {unit}\n") + + # --- Split Proofreading --- + def split_proofreading( + self, + model, + proposals_config, + batch_size=32, + dt=0.05, + min_threshold=0.8, + removal_threshold=0.3, + save_result=True, + ): + # Create proofreader + proofreader = SplitProofreader( + self.graph, + model, + self.img_config, + self.output_dir, + batch_size=batch_size, + device=self.device, + log_handle=self.log_handle, + ) + + # Run inference + self.step_cnt += 1 + self.log(f"\nStep {self.step_cnt}: Split Proofreading") + proofreader( + proposals_config, + dt=dt, + min_threshold=min_threshold, + removal_threshold=removal_threshold, + ) + + # Save final graph + if save_result: + self.log("Final Graph...") + self.log(self.graph.__repr__()) + self.reconfigure_node_radius() + self.save_graph("corrected_swcs") + + # --- Merge Proofreading --- + def merge_proofreading(self, mode): + # Report step + self.step_cnt += 1 + self.log( + f"\nStep {self.step_cnt}: Merge Proofreading with mode={mode}" + ) + + # Detect merges + if mode == "heuristic": + merge_sites = self.graph.remove_high_risk_merges() + elif mode == "connected_somas": + results = self.graph.remove_soma_merges() + self.log(results) + + # Report results + self.log(f"# Merges Detected: {len(merge_sites)}") + + # Save sites + color = "# COLOR 1.0 0.0 0.0" + zip_path = os.path.join(self.output_dir, f"{mode}_merge_sites.zip") + swc_util.write_points( + zip_path, merge_sites, color=color, prefix="merge_site", radius=10 + ) + + # --- Helpers --- + def log(self, txt): + """ + Logs and prints the given text. + + Parameters + ---------- + txt : str + Text to be logged and printed. + """ + print(txt) + self.log_handle.write(txt) + self.log_handle.write("\n") + + def reconfigure_node_radius(self): + n_nodes = len(self.graph.node_radius) + self.graph.node_radius = np.ones((n_nodes), dtype=np.float16) + for i, j in self.graph.accepts: + self.graph.node_radius[i] = 6 + self.graph.node_radius[j] = 6 + + def save_fragment_ids(self): + path = f"{self.output_dir}/segment_ids.txt" + segment_ids = list(self.graph.component_id_to_swc_id.values()) + util.write_list(path, segment_ids) + + def save_graph(self, dirname): + # Save graph across set of ZIPs + temp_dir = os.path.join(self.output_dir, "temp") + self.graph.to_zipped_swcs_multithreaded(temp_dir) + + # Combine ZIPs into single ZIP + zip_paths = util.list_paths(temp_dir, extension=".zip") + final_zip_path = os.path.join(self.output_dir, dirname, "swcs.zip") + util.mkdir(os.path.join(self.output_dir, dirname)) + util.combine_zips(zip_paths, final_zip_path) + util.rmdir(temp_dir) diff --git a/src/neuron_proofreader/proposal_graph.py b/src/neuron_proofreader/proposal_graph.py index 238d53d..f4bf813 100644 --- a/src/neuron_proofreader/proposal_graph.py +++ b/src/neuron_proofreader/proposal_graph.py @@ -22,7 +22,7 @@ ProposalGenerator, trim_proposal_endpoints, ) -from neuron_proofreader.utils import geometry_util, graph_util +from neuron_proofreader.utils import geometry_util class ProposalGraph(SkeletonGraph): @@ -62,36 +62,22 @@ def __init__( graph. Default is True. """ # Call parent class - super().__init__() - - # Instance attributes - Graph - self.anisotropy = anisotropy - self.component_id_to_swc_id = dict() - self.gt_path = gt_path - self.soma_component_ids = set() - self.verbose = verbose + super().__init__( + anisotropy=anisotropy, + min_cable_length=min_cable_length, + node_spacing=node_spacing, + prune_depth=prune_depth, + verbose=verbose, + ) # Instance attributes - Proposals self.accepts = set() self.gt_accepts = set() + self.gt_path = gt_path self.merged_ids = set() self.n_merges_blocked = 0 self.n_proposals_blocked = 0 - self.reset_proposals() - self.proposal_generator = ProposalGenerator( - self, - max_proposals_per_leaf=max_proposals_per_leaf, - ) - - # Graph Loader - self.graph_loader = graph_util.GraphLoader( - anisotropy=anisotropy, - min_cable_length=min_cable_length, - node_spacing=node_spacing, - prune_depth=prune_depth, - verbose=verbose, - ) # --- Update Structure --- def relabel_nodes(self): @@ -153,11 +139,13 @@ def generate_proposals( """ # Proposal generation assert len(self.kdtree.data) == self.number_of_nodes() - proposals = self.proposal_generator( - search_radius, + proposal_generator = ProposalGenerator( + self, allow_nonleaf_proposals=allow_nonleaf_proposals, + max_proposals_per_leaf=max_proposals_per_leaf, min_size_with_proposals=min_size_with_proposals, ) + proposals = proposal_generator(search_radius) self.search_radius = search_radius self.store_proposals(proposals) @@ -328,7 +316,7 @@ def proposal_length(self, proposal): return self.dist(*tuple(proposal)) def proposal_midpoint(self, proposal): - return geometry_util.midpoint(*self.proposal_xyz(proposal)) + return self.midpoint(*proposal) def proposal_radius(self, proposal): i, j = proposal diff --git a/src/neuron_proofreader/skeleton_graph.py b/src/neuron_proofreader/skeleton_graph.py index 697638b..5c4248a 100644 --- a/src/neuron_proofreader/skeleton_graph.py +++ b/src/neuron_proofreader/skeleton_graph.py @@ -85,6 +85,7 @@ def __init__( self.node_spacing = node_spacing self.soma_centroids = list() self.soma_component_ids = list() + self.verbose = verbose # Graph Loader anisotropy = anisotropy if use_anisotropy else (1.0, 1.0, 1.0) @@ -901,6 +902,9 @@ def leaf_nodes(self): """ return [i for i in self.nodes if self.degree[i] == 1] + def midpoint(self, i, j): + return geometry_util.midpoint(self.node_xyz[i], self.node_xyz[j]) + def node_local_voxel(self, node, offset): """ Computes the local voxel coordinate of the given node within the image @@ -1124,8 +1128,8 @@ def remove_high_risk_merges(self, max_dist=7): somas_kdtree = KDTree(self.node_xyz[soma_nodes]) # Iterate over branching nodes - cnt = 0 rm_nodes = set() + merge_sites = list() while branching_nodes: # Set root of search root = branching_nodes.pop() @@ -1145,6 +1149,7 @@ def remove_high_risk_merges(self, max_dist=7): i, dist_i = queue.pop() if self.degree[i] > 2 and i != root: hit_branching_nodes.add(i) + merge_sites.append(self.midpoint(root, i)) # Update queue for j in self.neighbors(i): @@ -1157,9 +1162,10 @@ def remove_high_risk_merges(self, max_dist=7): if hit_branching_nodes or self.degree(root) > 3: rm_nodes = rm_nodes.union(visited) branching_nodes -= hit_branching_nodes - cnt += 1 + + # Update graph self.remove_nodes(rm_nodes) - return f"# High Risk Merges: {cnt}" + return merge_sites def rooted_subgraph(self, root, radius): """ @@ -1225,12 +1231,12 @@ def __repr__(self): n_edges = format(self.number_of_edges(), ",") memory = util.get_memory_usage() return ( - f"SkeletonGraph(" - f"\n n_components={n_components}," - f"\n n_nodes={n_nodes}," - f"\n n_edges={n_edges}," - f"\n memory={memory:.2f} GBs" - f"\n)" + f" SkeletonGraph(\n" + f" num_connected_components={n_components},\n" + f" num_nodes={n_nodes},\n" + f" num_edges={n_edges},\n" + f" memory={memory:.2f} GBs,\n" + f" )" ) def swc_ids(self): diff --git a/src/neuron_proofreader/split_proofreading/proposal_generation.py b/src/neuron_proofreader/split_proofreading/proposal_generation.py index 729597c..9a70cc6 100644 --- a/src/neuron_proofreader/split_proofreading/proposal_generation.py +++ b/src/neuron_proofreader/split_proofreading/proposal_generation.py @@ -24,8 +24,10 @@ class ProposalGenerator: def __init__( self, graph, + allow_nonleaf_proposals=False, max_attempts=2, max_proposals_per_leaf=3, + min_size_with_proposals=0, search_scaling_factor=1.5, ): """ @@ -35,29 +37,31 @@ def __init__( ---------- graph : ProposalGraph Graph that proposals will be generated for. + allow_nonleaf_proposals : bool, optional + Indication of whether to generate proposals between leaf and nodes + with degree 2. Default is False. max_attempts : int, optional Number of attempts made to generate proposals from a node with increasing search radii. Default is 2. max_proposals_per_leaf : bool, optional Maximum number of proposals generated at each leaf. Default is 3. + min_size_with_proposals : float, optional + Minimum cable path length required for fragments that proposals + are generated from. Default is 0. search_scaling_factor : 1.5, optional Scaling actor used to enlarge search radius for each search. Default is 2. """ # Instance attributes - self.allow_nonleaf_proposals = None + self.allow_nonleaf_proposals = allow_nonleaf_proposals self.graph = graph self.kdtree = None self.max_attempts = max_attempts self.max_proposals_per_leaf = max_proposals_per_leaf + self.min_size_with_proposals = min_size_with_proposals self.search_scaling_factor = search_scaling_factor - def __call__( - self, - initial_radius, - allow_nonleaf_proposals=False, - min_size_with_proposals=0, - ): + def __call__(self, initial_radius): """ Generates edge proposals between fragments within the given search radius. @@ -67,15 +71,8 @@ def __call__( initial_radius : float Initial search radius used to generate proposals between endpoints of proposal. - allow_nonleaf_proposals : bool, optional - Indication of whether to generate proposals between leaf and nodes - with degree 2. Default is False. - min_size_with_proposals : float, optional - Minimum cable path length required for fragments that proposals - are generated from. Default is 0. """ # Initializations - self.allow_nonleaf_proposals = allow_nonleaf_proposals self.set_kdtree() iterator = self.graph.leaf_nodes() if self.graph.verbose: @@ -87,9 +84,9 @@ def __call__( for leaf in iterator: # Check if fragment satisfies size requirement length = self.graph.cable_length( - max_depth=min_size_with_proposals, root=leaf + max_depth=self.min_size_with_proposals, root=leaf ) - if length < min_size_with_proposals: + if length < self.min_size_with_proposals: continue # Generate proposals diff --git a/src/neuron_proofreader/split_proofreading/split_datasets.py b/src/neuron_proofreader/split_proofreading/split_datasets.py index ad81b22..635ef53 100644 --- a/src/neuron_proofreader/split_proofreading/split_datasets.py +++ b/src/neuron_proofreader/split_proofreading/split_datasets.py @@ -17,8 +17,6 @@ import os import pandas as pd -from neuron_proofreader.proposal_graph import ProposalGraph -from neuron_proofreader.machine_learning.augmentation import ImageTransforms from neuron_proofreader.machine_learning.subgraph_sampler import ( SubgraphSampler, ) @@ -26,10 +24,10 @@ FeaturePipeline, HeteroGraphData, ) -from neuron_proofreader.utils import geometry_util, util +from neuron_proofreader.utils import util -# --- Single Brain Dataset --- +# --- Datasets --- class FragmentsDataset(IterableDataset): """ A dataset object that contains a graph built from fragments corresponding @@ -39,13 +37,11 @@ class FragmentsDataset(IterableDataset): def __init__( self, - fragments_path, - img_path, - config, + fragments_graph, + img_config, + batch_size=32, gt_path=None, - metadata_path=None, prefetch=4, - soma_centroids=set(), ): """ Instantiates a FragmentsDataset object. @@ -56,65 +52,26 @@ def __init__( Path to predicted SWC files to be loaded. img_path : str Path to the raw image associated with the fragments. - config : Config - Configuration object containing parameters and settings. + graph_config : GraphConfig + ... gt_path : str, optional Path to ground-truth SWC files to be loaded. Default is None. - metadata_path : str, optional - Patch to JSON file containing metadata on block that fragments - were extracted from. Default is None. - soma_centroids : List[Tuple[int]], optional - Phyiscal coordinates of soma centroids. Default is an empty list. """ # Instance attributes - self.config = config + self.batch_size = batch_size + self.graph = fragments_graph self.gt_path = gt_path self.prefetch = prefetch - self.transform = ImageTransforms() if config.ml.transform else False - - # Build graph - self._load_graph(fragments_path, metadata_path) - self.graph.load_somas(soma_centroids) + self.transform = img_config.transform # Feature extractor self.feature_extractor = FeaturePipeline( self.graph, - img_path, - brightness_clip=self.config.ml.brightness_clip, - patch_shape=self.config.ml.patch_shape, + img_config.img_path, + brightness_clip=img_config.brightness_clip, + patch_shape=img_config.patch_shape, ) - def _load_graph(self, fragments_path, metadata_path=None): - """ - Loads a graph by reading and processing SWC files specified by the - given path. - - Parameters - ---------- - fragments_path : str - Path to SWC files to be loaded. - metadata_path : str, optional - Patch to JSON file containing metadata on block that fragments - were extracted from. Default is None. - """ - # Build graph - self.graph = ProposalGraph( - anisotropy=self.config.graph.anisotropy, - gt_path=self.gt_path, - min_cable_length=self.config.graph.min_cable_length, - node_spacing=self.config.graph.node_spacing, - prune_depth=self.config.graph.prune_depth, - verbose=self.config.graph.verbose, - ) - self.graph.load(fragments_path) - - # Post process fragments - if metadata_path: - self.graph.clip_to_bbox(metadata_path) - - if self.config.graph.remove_doubles: - geometry_util.remove_doubles(self.graph, 200) - # --- Get Data --- def __iter__(self): """ @@ -137,18 +94,17 @@ def __getattr__(self, name): def get_sampler(self): """ - Gets a subgraph sampler that is used to iterate over dataset. + Gets a subgraph sampler used to iterate over dataset. Returns ------- sampler : SubgraphSampler Subgraph sampler that is used to iterate over dataset. """ - batch_size = self.config.ml.batch_size - return iter(SubgraphSampler(self.graph, max_proposals=batch_size)) + sampler = SubgraphSampler(self.graph, max_proposals=self.batch_size) + return iter(sampler) -# --- Multi-Brain Dataset --- class FragmentsDatasetCollection(IterableDataset): """ A dataset class for storing a set of FragmentDataset objects corresponding diff --git a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py index 7976898..7330221 100644 --- a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py +++ b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py @@ -18,6 +18,9 @@ import numpy as np import torch +from neuron_proofreader.machine_learning.image_dataloader import ( + TensorStoreImage, +) from neuron_proofreader.utils import geometry_util, graph_util, img_util, util from neuron_proofreader.utils.ml_util import TensorDict @@ -242,15 +245,12 @@ def __init__( Number of voxels to be added in each dimension from start and end point of proposal for image patch extraction. Default is 40. """ - # Instance attributes self.brightness_clip = brightness_clip self.graph = graph + self.img = TensorStoreImage(img_path) self.patch_shape = patch_shape self.padding = padding - # Image reader - self.img = img_util.TensorStoreImage(img_path) - def __call__(self, subgraph, features): """ Extracts image patches and profiles for each proposal in the graph. diff --git a/src/neuron_proofreader/split_proofreading/split_inference.py b/src/neuron_proofreader/split_proofreading/split_inference.py index e9e9eed..89320d1 100644 --- a/src/neuron_proofreader/split_proofreading/split_inference.py +++ b/src/neuron_proofreader/split_proofreading/split_inference.py @@ -19,7 +19,7 @@ based on the learned features. 3. Merge Accepted Proposals - Add accepted proposals to the graph as edges. + Add accepted proposals as edges to the graph. """ @@ -27,7 +27,6 @@ from tqdm import tqdm import networkx as nx -import numpy as np import pandas as pd import os import torch @@ -38,26 +37,20 @@ from neuron_proofreader.utils import ml_util, util -class InferencePipeline: +class SplitProofreader: """ - Class that executes the full split proofreader inference pipeline by - performing the following steps: - (1) Graph Construction - (2) Proposal Generation - (3) Proposal Classification. + Class that executes the full split proofreader inference pipeline. """ def __init__( self, - fragments_path, - img_path, - output_dir, + graph, model, - graph_config, - ml_config, - proposals_config, - log_preamble="", - soma_centroids=list(), + img_config, + output_dir, + batch_size=32, + device="cuda", + log_handle=None, ): """ Initializes an object that executes the full split correction @@ -65,95 +58,46 @@ def __init__( Parameters ---------- - fragments_path : str - Path to SWC files to be loaded into graph. - img_path : str - Path to whole-brain image corresponding to the given fragments. - output_dir : str - Directory where the results of the inference will be saved. - ... - log_preamble : str, optional - String to be added to the beginning of log. Default is an empty - string. - soma_centroids : List[Tuple[float]], optional - Physcial coordinates of soma centroids. Default is an empty list. + ... """ # Instance attributes - self.accepted_proposals = list() - self.config = config - self.img_path = img_path - self.model = model.to(config.ml.device) + self.dataset = FragmentsDataset( + graph, + img_config, + batch_size=batch_size, + ) + self.device = device + self.model = model self.output_dir = output_dir - self.soma_centroids = soma_centroids # Logger - util.mkdir(self.output_dir) log_path = os.path.join(self.output_dir, "summary.txt") - self.log_handle = open(log_path, "a") - self.log(log_preamble) - - # Load data - self._load_data(fragments_path, img_path) + self.log_handle = log_handle or open(log_path, "a") - def _load_data(self, fragments_path, img_path): - """ - Builds a graph from the given fragments. - - Parameters - ---------- - fragments_path : str - Path to SWC files to be loaded into graph. - img_path : str - Path to whole-brain image corresponding to the given fragments. - """ - # Load data - t0 = time() - self.log("Step 1: Build Graph") - self.dataset = FragmentsDataset( - fragments_path, - img_path, - self.config, - soma_centroids=self.soma_centroids, - ) - self.log(self.dataset.summary(prefix="\nInitial")) - self.save_fragment_ids() - self.save_graph("original_swcs") - - # Postprocess fragments with somas - self.log(self.dataset.remove_soma_merges()) - self.log(self.dataset.connect_soma_fragments()) - - # Break high risk merges (if applicable) - if self.config.graph.remove_high_risk_merges: - self.log(self.dataset.remove_high_risk_merges()) - self.log(self.dataset.summary(prefix="\nPre-Corrected")) - self.save_graph("precorrected_swcs") - - # Report runtime - elapsed, unit = util.time_writer(time() - t0) - self.log(f"Module Runtime: {elapsed:.2f} {unit}\n") - - # --- Pipelines --- def __call__( - self, search_radius, dt=0.1, min_threshold=0.75, removal_threshold=0.3 + self, + proposals_config, + dt=0.1, + min_threshold=0.8, + removal_threshold=0.3, ): """ Executes the full inference pipeline. Parameters ---------- - search_radius : float - Search radius (in microns) used to generate proposals. + proposals_config : ProposalsConfig + Config object with settings for proposal generation. dt : float, optional Increment that acceptance threshold is lowered by. Default is 0.1. min_threshold : float, optional - Minimum threshold for accepting proposals. Default is 0.75. + Minimum threshold for accepting proposals. Default is 0.8. removal_threshold : float, optional Proposals with model predictions less than this value are removed. Default is 0.3. """ # Generate proposals - self.generate_proposals(search_radius) + self.generate_proposals(proposals_config) total_proposals = self.dataset.n_proposals() # Run inference @@ -168,80 +112,63 @@ def __call__( # Generate predictions while self.dataset.proposals: - # Generate predictons + # Generate proposal predictons cnt += 1 self.log( - f"\nThreshold={new_threshold} w/ only_leaf2leaf={only_leaf2leaf}" + f"\n--- Threshold={new_threshold} w/ only_leaf2leaf={only_leaf2leaf} ---" ) preds = self.predict_proposals( suffix=f"{name}_round={cnt}_threshold={new_threshold}" ) - # Merge accetped proposals + # Merge accepted proposals cur_threshold = new_threshold self.merge_with_threshold_schedule( - preds, cur_threshold, only_leaf2leaf=only_leaf2leaf + preds, cur_threshold, dt=dt, only_leaf2leaf=only_leaf2leaf ) - self.filter_proposals(preds, removal_threshold) - # Update threshold + # Remove rejected proposals + self.remove_proposals(preds, removal_threshold) + + # Update acceptance threshold new_threshold = max(cur_threshold - dt, min_threshold) if cur_threshold == new_threshold: break # Report results t, unit = util.time_writer(time() - t0) - p_accepts = len(self.dataset.accepts) / total_proposals - self.log(self.dataset.summary(prefix="\nFinal")) - self.log(f"Overall Acceptance Rate: {p_accepts:.2f}") + p_accepts = 100 * len(self.dataset.accepts) / total_proposals + self.log(f"Overall Accepted: {p_accepts:.2f}%") self.log(f"Total Runtime: {t:.2f} {unit}\n") - self.save_results() + self.save_connections() # --- Core Routines --- - def filter_proposals(self, preds, threshold): - # Remove based on model predictions and mergeability - cnt = 0 - for proposal, pred in preds.items(): - is_valid = self.dataset.is_mergeable(*proposal) - if pred < threshold or not is_valid: - self.dataset.remove_proposal(proposal) - cnt += 1 - - # Sanity check - for proposal in self.dataset.list_proposals(): - i, j = proposal - if self.dataset.degree[i] > 2 or self.dataset.degree[j] > 2: - self.dataset.remove_proposal(proposal) - cnt += 1 - - self.log("Filter Proposals") - self.log(f"# Proposals Removed: {cnt}") - self.log(f"# Proposals Remaining: {self.dataset.n_proposals()}\n") - - def generate_proposals(self, search_radius): + def generate_proposals(self, proposals_config): """ Generates proposals for the fragments graph based on the specified configuration. Parameters ---------- - search_radius : float - Search radius (in microns) used to generate proposals. + proposals_config : ProposalsConfig + Config object with settings for proposal generation. """ # Main t0 = time() - self.log("\nStep 2: Generate Proposals") - self.log(f"Search Radius: {search_radius}") + self.log("\nGenerate Proposals...") self.dataset.generate_proposals( - search_radius, - allow_nonleaf_proposals=self.config.graph.allow_nonleaf_proposals, + proposals_config.search_radius, + allow_nonleaf_proposals=proposals_config.allow_nonleaf_proposals, + max_proposals_per_leaf=proposals_config.max_proposals_per_leaf, + min_size_with_proposals=proposals_config.min_size_with_proposals, ) + # Report results n_proposals = format(self.dataset.n_proposals(), ",") n_proposals_blocked = self.dataset.n_proposals_blocked - - # Report results t, unit = util.time_writer(time() - t0) + + self.log(f"Search Radius: {proposals_config.search_radius}") self.log(f"# Proposals: {n_proposals}") self.log(f"# Proposals Blocked: {n_proposals_blocked}") self.log(f"Module Runtime: {t:.2f} {unit}\n") @@ -268,7 +195,6 @@ def merge_with_threshold_schedule( """ # Initializations t0 = time() - self.log("\nStep 3: Run Inference") n_proposals = self.dataset.n_proposals() n_accepts = 0 @@ -288,6 +214,7 @@ def merge_with_threshold_schedule( # Report results t, unit = util.time_writer(time() - t0) + self.log("Inference...") self.log(f"# Merges Blocked: {self.dataset.n_merges_blocked}") self.log(f"# Accepted: {format(n_accepts, ',')}") self.log(f"% Accepted: {100 * n_accepts / n_proposals:.2f}") @@ -311,13 +238,13 @@ def predict_proposals(self, suffix=""): pbar.update(data.n_proposals()) # Save results - self.save_proposal_results(preds, suffix=suffix) + self.save_model_predictions(preds, suffix=suffix) return preds def merge_proposals(self, preds, threshold, only_leaf2leaf=False): """ - Merges nodes corresponding to for proposals that satify the threshold - and no loop creation requirements. + Merges proposals with model prediction above threshold and does + not create a loop. Parameters ---------- @@ -350,16 +277,25 @@ def merge_proposals(self, preds, threshold, only_leaf2leaf=False): del preds[proposal] return n_accepts - def save_results(self): - """ - Saves the processed results from running the inference pipeline, - namely the corrected SWC files and a list of the merged SWC ids. - """ - self.reconfigure_node_radius() - self.save_graph("corrected_swcs") - self.save_connections() - self.config.save(self.output_dir) - self.log_handle.close() + def remove_proposals(self, preds, threshold): + # Remove based on model predictions and mergeability + cnt = 0 + for proposal, pred in preds.items(): + is_valid = self.dataset.is_mergeable(*proposal) + if pred < threshold or not is_valid: + self.dataset.remove_proposal(proposal) + cnt += 1 + + # Sanity check + for proposal in self.dataset.list_proposals(): + i, j = proposal + if self.dataset.degree[i] > 2 or self.dataset.degree[j] > 2: + self.dataset.remove_proposal(proposal) + cnt += 1 + + self.log("Remove Proposals...") + self.log(f"# Proposals Removed: {cnt}") + self.log(f"# Proposals Remaining: {self.dataset.n_proposals()}\n") # --- Helpers --- def log(self, txt): @@ -391,8 +327,7 @@ def predict(self, data): """ # Generate predictions with torch.inference_mode(): - device = self.config.ml.device - x = data.get_inputs().to(device) + x = data.get_inputs().to(self.device) with torch.cuda.amp.autocast(enabled=True): hat_y = torch.sigmoid(self.model(x)) @@ -401,7 +336,17 @@ def predict(self, data): hat_y = ml_util.tensor_to_list(hat_y) return {idx_to_id[i]: y_i for i, y_i in enumerate(hat_y)} - def save_proposal_results(self, preds_dict, suffix=""): + def save_connections(self): + """ + Writes accepted proposals to a text file. Each line contains the two + SWC IDs as comma separated values. + """ + path = os.path.join(self.output_dir, "connections.txt") + with open(path, "w") as f: + for id1, id2 in self.dataset.merged_ids: + f.write(f"{id1}, {id2}" + "\n") + + def save_model_predictions(self, preds_dict, suffix=""): summary = list() for proposal, pred in preds_dict.items(): # Extract info @@ -428,25 +373,3 @@ def save_proposal_results(self, preds_dict, suffix=""): # Save results path = os.path.join(self.output_dir, f"proposal_summary{suffix}.csv") pd.DataFrame(summary).set_index("Proposal").to_csv(path) - - def reconfigure_node_radius(self): - n_nodes = len(self.dataset.node_radius) - self.dataset.node_radius = np.ones((n_nodes), dtype=np.float16) - for i, j in self.dataset.accepts: - self.dataset.node_radius[i] = 6 - self.dataset.node_radius[j] = 6 - - def save_connections(self): - """ - Writes the accepted proposals to a text file. Each line contains the - two SWC IDs as comma separated values. - """ - path = os.path.join(self.output_dir, "connections.txt") - with open(path, "w") as f: - for id1, id2 in self.dataset.merged_ids: - f.write(f"{id1}, {id2}" + "\n") - - def save_fragment_ids(self): - path = f"{self.output_dir}/segment_ids.txt" - segment_ids = list(self.dataset.component_id_to_swc_id.values()) - util.write_list(path, segment_ids) diff --git a/src/neuron_proofreader/utils/img_util.py b/src/neuron_proofreader/utils/img_util.py index c0972f0..bbee524 100644 --- a/src/neuron_proofreader/utils/img_util.py +++ b/src/neuron_proofreader/utils/img_util.py @@ -203,15 +203,13 @@ def compute_iou3d(c1, c2, s1, s2): return np.prod(overlap) / union if union > 0 else 0 -def find_img_path(bucket_name, root_dir, brain_id): +def find_img_path(root_prefix, brain_id): """ Finds the path to a whole-brain dataset stored in a GCS bucket. Parameters: ---------- - bucket_name : str - Name of the GCS bucket where the images are stored. - root_dir : str + root_prefrix : str Path to the directory in the GCS bucket where the image is expected to be located. dataset_name : str @@ -222,11 +220,12 @@ def find_img_path(bucket_name, root_dir, brain_id): str Path of the found dataset subdirectory within the specified GCS bucket. """ - for subdir in util.list_gcs_subdirectories(bucket_name, root_dir): - if brain_id in subdir: - img_path = f"gs://{bucket_name}/{subdir}whole-brain/fused.zarr" + bucket_name, _ = util.parse_cloud_path(root_prefix) + for prefix in util.list_gcs_subprefixes(root_prefix): + if brain_id in prefix: + img_path = f"gs://{bucket_name}/{prefix}whole-brain/fused.zarr" return img_path - raise f"Dataset not found in {bucket_name} - {root_dir}" + raise f"Dataset not found in {root_prefix}" def get_contained_voxels(voxels, shape, buffer=0): From 9a6195f6d6946e6faed136f254f241fb1d353d2c Mon Sep 17 00:00:00 2001 From: anna-grim Date: Sun, 14 Jun 2026 22:58:42 +0000 Subject: [PATCH 09/10] refactor: heuristic merge, connected soma fragments --- src/neuron_proofreader/proofreading_pipeline.py | 13 +++++++++---- src/neuron_proofreader/skeleton_graph.py | 12 ++++++------ .../split_proofreading/split_inference.py | 2 +- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/neuron_proofreader/proofreading_pipeline.py b/src/neuron_proofreader/proofreading_pipeline.py index 4528978..0cb8f8a 100644 --- a/src/neuron_proofreader/proofreading_pipeline.py +++ b/src/neuron_proofreader/proofreading_pipeline.py @@ -144,6 +144,12 @@ def split_proofreading( self.reconfigure_node_radius() self.save_graph("corrected_swcs") + def connect_soma_fragments(self, max_dist=25): + self.step_cnt += 1 + self.log(f"\nStep {self.step_cnt}: Connect Soma Fragments with dist={max_dist}") + summary = self.graph.connect_soma_fragments(max_dist=max_dist) + self.log(summary) + # --- Merge Proofreading --- def merge_proofreading(self, mode): # Report step @@ -154,13 +160,12 @@ def merge_proofreading(self, mode): # Detect merges if mode == "heuristic": - merge_sites = self.graph.remove_high_risk_merges() + merge_sites, summary = self.graph.remove_high_risk_merges() elif mode == "connected_somas": - results = self.graph.remove_soma_merges() - self.log(results) + merge_sites, summary = self.graph.remove_soma_merges() # Report results - self.log(f"# Merges Detected: {len(merge_sites)}") + self.log(summary) # Save sites color = "# COLOR 1.0 0.0 0.0" diff --git a/src/neuron_proofreader/skeleton_graph.py b/src/neuron_proofreader/skeleton_graph.py index 5c4248a..cab8b49 100644 --- a/src/neuron_proofreader/skeleton_graph.py +++ b/src/neuron_proofreader/skeleton_graph.py @@ -317,6 +317,7 @@ def remove_soma_merges(self): component_id_to_soma_nodes[self.node_component_id[i]].add(i) # Check for connected components with multiple soma nodes + merge_sites = list() n_soma_merges = 0 for component_id, soma_nodes in component_id_to_soma_nodes.items(): if len(soma_nodes) > 1 and len(soma_nodes) < 20: @@ -326,16 +327,14 @@ def remove_soma_merges(self): dist, _ = somas_kdtree.query(self.node_xyz[i]) if self.degree[i] > 2 and dist > 25: rm_nodes.add(i) + merge_sites.append(self.node_xyz[i]) self.remove_nearby_nodes(rm_nodes) # Finish self.remove_small_components(relabel_nodes=False) self.relabel_nodes() - results = [ - f"# Soma Fragments: {len(self.soma_centroids)}", - f"# Soma Merges: {n_soma_merges}", - ] - return "\n".join(results) + summary = f"# Soma Merges: {n_soma_merges}" + return merge_sites, summary def remove_nearby_nodes(self, roots, max_dist=5.0): """ @@ -1165,7 +1164,8 @@ def remove_high_risk_merges(self, max_dist=7): # Update graph self.remove_nodes(rm_nodes) - return merge_sites + summary = f"# High Risk Merges: {len(merge_sites)}" + return merge_sites, summary def rooted_subgraph(self, root, radius): """ diff --git a/src/neuron_proofreader/split_proofreading/split_inference.py b/src/neuron_proofreader/split_proofreading/split_inference.py index 89320d1..e9055d3 100644 --- a/src/neuron_proofreader/split_proofreading/split_inference.py +++ b/src/neuron_proofreader/split_proofreading/split_inference.py @@ -155,7 +155,7 @@ def generate_proposals(self, proposals_config): """ # Main t0 = time() - self.log("\nGenerate Proposals...") + self.log("Generate Proposals...") self.dataset.generate_proposals( proposals_config.search_radius, allow_nonleaf_proposals=proposals_config.allow_nonleaf_proposals, From c9523e1e6b48bf1a70b2a03db52dda19c09be11d Mon Sep 17 00:00:00 2001 From: anna-grim Date: Sun, 14 Jun 2026 23:38:12 +0000 Subject: [PATCH 10/10] refactor: misc --- src/neuron_proofreader/skeleton_graph.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/neuron_proofreader/skeleton_graph.py b/src/neuron_proofreader/skeleton_graph.py index cab8b49..c53beda 100644 --- a/src/neuron_proofreader/skeleton_graph.py +++ b/src/neuron_proofreader/skeleton_graph.py @@ -296,15 +296,7 @@ def connect_soma_fragments(self, max_dist=25): def remove_soma_merges(self): """ - Breaks a fragment that intersects with multiple somas so that nodes - closest to soma locations are disconnected. - - Parameters - ---------- - swc_dict : dict - Contents of an SWC file. - somas_xyz : List[Tuple[float]] - Physical coordinates representing soma locations. + Removes branching points along paths that connect multiple somas. """ # Check whether to exit if len(self.soma_centroids) == 0: