diff --git a/src/neuron_proofreader/configs.py b/src/neuron_proofreader/configs.py index 90be82e..3a8e108 100644 --- a/src/neuron_proofreader/configs.py +++ b/src/neuron_proofreader/configs.py @@ -1,5 +1,5 @@ """ -Created on Frid Sept 15 16:00:00 2024 +Created on Fri Sept 15 16:00:00 2024 @author: Anna Grim @email: anna.grim@alleninstitute.org @@ -15,6 +15,7 @@ import os +from neuron_proofreader.machine_learning.augmentation import ImageTransforms from neuron_proofreader.utils import util @@ -97,13 +98,14 @@ class ImageConfig(Config): """ brightness_clip: int = 400 + img_path: str = None name: str = "image_config" percentiles: Tuple[float, float] = (1, 99.5) patch_shape: Tuple[int, int, int] = (128, 128, 128) - transform: bool = False + transform = None def set_train_mode(self): - self.transform = True + self.transform = ImageTransforms() @dataclass @@ -124,5 +126,8 @@ class ProposalsConfig(Config): """ allow_nonleaf_proposals: bool = False - proposals_per_leaf: int = 3 + max_proposals_per_leaf: int = 3 + min_size_with_proposals: float = 0 trim_endpoints_bool: bool = True + search_radius: float = 25 + search_scaling_factor: float = 1.5 diff --git a/src/neuron_proofreader/geometric_learning/curve_augmentation.py b/src/neuron_proofreader/geometric_learning/curve_augmentation.py new file mode 100644 index 0000000..e569a95 --- /dev/null +++ b/src/neuron_proofreader/geometric_learning/curve_augmentation.py @@ -0,0 +1,211 @@ +""" +Created on Wed June 11 12:00:00 2026 + +@author: Anna Grim +@email: anna.grim@alleninstitute.org + +Code for applying data augmentation to 3D space curves. + +""" + +import numpy as np +import random + + +class CurveTransforms: + """ + Class that applies a sequence of transforms to a 3D space curve. + """ + + def __init__(self): + """ + Initializes a CurveTransforms instance that applies augmentation to + a 3D space curve. + """ + self.transforms = [ + RandomRotation3D(), + RandomMirror3D(), + RandomJitter3D(), + ] + + def __call__(self, curve): + """ + Applies transforms to the input curve. + + Parameters + ---------- + curve : numpy.ndarray + Array of shape (N, 3) representing N points in 3D space (x, y, z). + """ + # Check whether to reverse path + if random.random() > 0.5: + curve = np.flip(curve) + + # Apply transforms + for transform in self.transforms: + curve = transform(curve) + return curve + + +# --- Noise Transforms --- +class RandomJitter3D: + """ + Randomly adds Gaussian noise to each point in a 3D curve. + """ + + def __init__(self, sigma=0.1, p=0.5): + """ + Initializes a RandomJitter3D transformer. + + Parameters + ---------- + sigma : float, optional + Standard deviation of the Gaussian noise. Default is 0.01. + p : float, optional + Probability of applying the transform. Default is 0.5. + """ + self.sigma = sigma + self.p = p + + def __call__(self, curve): + """ + Applies random jitter to the input curve. + + Parameters + ---------- + curve : numpy.ndarray + Array of shape (N, 3) representing N points in 3D space. + + Returns + ------- + numpy.ndarray + Jittered curve of shape (N, 3). + """ + if random.random() > self.p: + return curve + noise = np.random.normal(0, self.sigma, size=curve.shape) + return curve + noise + + +# --- Geometric Transforms --- +class RandomRotation3D: + """ + Applies a random 3D rotation to a curve about a random axis. + """ + + def __init__(self, max_angle=np.pi, p=0.5): + """ + Initializes a RandomRotation3D transformer. + + Parameters + ---------- + max_angle : float, optional + Maximum rotation angle in radians. Default is pi (full rotation). + p : float, optional + Probability of applying the transform. Default is 0.5. + """ + self.max_angle = max_angle + self.p = p + + def _rotation_matrix(self, axis, angle): + """ + Computes the Rodrigues rotation matrix for a given axis and angle. + + Parameters + ---------- + axis : numpy.ndarray + Unit vector of shape (3,) representing the rotation axis. + angle : float + Rotation angle in radians. + + Returns + ------- + numpy.ndarray + Rotation matrix of shape (3, 3). + """ + c, s = np.cos(angle), np.sin(angle) + x, y, z = axis + return np.array( + [ + [ + c + x * x * (1 - c), + x * y * (1 - c) - z * s, + x * z * (1 - c) + y * s, + ], + [ + y * x * (1 - c) + z * s, + c + y * y * (1 - c), + y * z * (1 - c) - x * s, + ], + [ + z * x * (1 - c) - y * s, + z * y * (1 - c) + x * s, + c + z * z * (1 - c), + ], + ] + ) + + def __call__(self, curve): + """ + Applies a random rotation to the input curve about its centroid. + + Parameters + ---------- + curve : numpy.ndarray + Array of shape (N, 3) representing N points in 3D space. + + Returns + ------- + numpy.ndarray + Rotated curve of shape (N, 3). + """ + if random.random() > self.p: + return curve + axis = np.random.randn(3) + axis /= np.linalg.norm(axis) + angle = random.uniform(-self.max_angle, self.max_angle) + R = self._rotation_matrix(axis, angle) + centroid = curve.mean(axis=0) + return (curve - centroid) @ R.T + centroid + + +class RandomMirror3D: + """ + Randomly mirrors a 3D curve along one or more axes about its centroid. + """ + + def __init__(self, axes=(0, 1, 2), p=0.5): + """ + Initializes a RandomMirror3D transformer. + + Parameters + ---------- + axes : Tuple[int], optional + Axes to consider for mirroring. Default is (0, 1, 2). + p : float, optional + Per-axis probability of mirroring. Default is 0.5. + """ + self.axes = axes + self.p = p + + def __call__(self, curve): + """ + Applies random mirroring to the input curve. + + Parameters + ---------- + curve : numpy.ndarray + Array of shape (N, 3) representing N points in 3D space. + + Returns + ------- + numpy.ndarray + Mirrored curve of shape (N, 3). + """ + curve = curve.copy() + centroid = curve.mean(axis=0) + for axis in self.axes: + if random.random() > self.p: + continue + curve[:, axis] = 2 * centroid[axis] - curve[:, axis] + return curve diff --git a/src/neuron_proofreader/geometric_learning/curve_datamodules.py b/src/neuron_proofreader/geometric_learning/curve_datamodules.py new file mode 100644 index 0000000..e957f44 --- /dev/null +++ b/src/neuron_proofreader/geometric_learning/curve_datamodules.py @@ -0,0 +1,264 @@ +""" +Created on Mon June 8 17:00:00 2026 + +@author: Anna Grim +@email: anna.grim@alleninstitute.org + +... + +""" + +from copy import deepcopy +from torch.utils.data import Dataset, DataLoader, Sampler + +import networkx as nx +import numpy as np +import pandas as pd +import torch + +from neuron_proofreader.skeleton_graph import SkeletonGraph + + +# --- Dataset Classes --- +class PathsDataset(Dataset): + + def __init__( + self, + brain_id, + swcs_path, + graph_config=None, + max_length=np.inf, + transform=None, + ): + # Instance attributes + self.brain_id = brain_id + self.max_length = max_length + self.transform = transform + + # Core data structures + self.graph = self.load_skeletons(graph_config, swcs_path) + self.paths = self.get_valid_paths() + + def load_skeletons(self, config, swcs_path): + graph = SkeletonGraph( + anisotropy=config.anisotropy, + min_cable_length=config.min_cable_length, + node_spacing=config.node_spacing, + use_anisotropy=config.use_anisotropy, + verbose=config.verbose, + ) + graph.load(swcs_path) + return graph + + def get_valid_paths(self): + paths = list() + for p in self.irreducible_paths(): + if self.path_length(p) < self.max_length: + paths.append(p) + return paths + + # --- Get Examples --- + def __getitem__(self, i): + # Get path + curve = deepcopy(self.node_xyz[self.paths[i]]) + if self.transform: + curve = self.transform(curve) + + # Normalize + curve -= curve[0] + curve[1:] -= curve[:-1] + return curve + + # --- Helpers --- + def path_lengths(self): + return np.array([self.path_length(p) for p in self.paths]) + + def __getattr__(self, name): + return getattr(self.graph, name) + + def __len__(self): + return len(self.paths) + + def __repr__(self): + lengths = self.path_lengths() + num_neurons = nx.number_connected_components(self.graph) + return ( + f"BrainDataset(" + f"\n brain_id={self.brain_id}, " + f"\n num_neurons={num_neurons}, " + f"\n num_paths={len(self)}, " + f"\n min_length={np.min(lengths):.2f}, " + f"\n mean_length={np.mean(lengths):.2f}, " + f"\n max_length={np.max(lengths):.2f}," + f"\n)" + ) + + +class PathsDatasetCollection(Dataset): + + def __init__(self, datasets, is_val=False, n_val_examples=1000, seed=42): + """ + Parameters + ---------- + datasets : List[PathsDataset] + List of PathsDataset instances, one per brain. + is_val : bool, optional + If True, precomputes a fixed set of examples at construction time. + Default is False. + n_val_examples : int, optional + Number of fixed validation examples to precompute. Default is 1000. + seed : int, optional + Random seed for reproducible val set. Default is 42. + """ + # Instance attributes + self.datasets = datasets + self.is_val = is_val + self.set_examples_df() + + # Check whether to set validation examples + if is_val: + self.val_examples = self.set_val_examples(n_val_examples, seed) + + def set_examples_df(self): + rows = [] + for ds_idx, dataset in enumerate(self.datasets): + ds_idxs = np.full(len(dataset), ds_idx) + p_idxs = np.arange(len(dataset)) + ds_lengths = dataset.path_lengths() + rows.append( + pd.DataFrame( + { + "ds_idx": ds_idxs, + "path_idx": p_idxs, + "length": ds_lengths, + } + ) + ) + self.examples_df = pd.concat(rows, ignore_index=True) + + def set_val_examples(self, n, seed): + """ + Samples n examples with fixed seed, strips transforms, and caches + the resulting examples. + """ + rng = np.random.default_rng(seed) + indices = rng.choice(len(self.examples_df), size=n, replace=False) + examples = [] + for i in indices: + ds_idx = self.examples_df["ds_idx"][i] + path_idx = self.examples_df["path_idx"][i] + dataset = self.datasets[ds_idx] + examples.append(dataset[path_idx]) + return examples + + # --- Data Fetching --- + def __getitem__(self, i): + # Case 1: validation example + if self.is_val: + return self.val_examples[i] + + # Case 2: train example + ds_idx = self.examples_df["ds_idx"][i] + path_idx = self.examples_df["path_idx"][i] + return self.datasets[ds_idx][path_idx] + + def __len__(self): + if self.is_val: + return len(self.val_examples) + return len(self.examples_df) + + def __repr__(self): + return ( + f"PathsDatasetCollection(" + f"num_brains={len(self.datasets)}, " + f"num_paths={len(self.examples_df)}) " + ) + + +# --- DataLoader Classes --- +class PathSampler(Sampler): + + def __init__(self, dataset, examples_per_epoch): + """ + Parameters + ---------- + dataset : PathsDatasetCollection + Dataset to sample from. + """ + self.dataset = dataset + self.examples_per_epoch = examples_per_epoch + + def __iter__(self): + idxs = self.dataset.examples_df.sample( + self.examples_per_epoch, replace=True, weights="length" + ).index + return iter(np.array(idxs)) + + def __len__(self): + return self.examples_per_epoch + + +def collate_curves(curves): + """ + Pads a list of curves to the longest in the batch and generates an + attention mask. + + Parameters + ---------- + curves : List[numpy.ndarray] + Each of shape (N_i, 3), where N_i can vary. + + Returns + ------- + padded : torch.Tensor + Shape (B, N_max, 3), zero-padded. + mask : torch.Tensor + Shape (B, N_max), True where padding. + """ + lengths = [len(c) for c in curves] + n_max = max(lengths) + B = len(curves) + + padded = torch.zeros(B, n_max, 3) + mask = torch.ones(B, n_max, dtype=torch.bool) + for i, (c, l) in enumerate(zip(curves, lengths)): + padded[i, :l] = torch.tensor(c) + mask[i, :l] = False + return padded, mask + + +def build_dataloader( + dataset, + batch_size=32, + examples_per_epoch=5000, + num_workers=0, + use_sampler=True, +): + """ + Builds a DataLoader for a PathsDatasetCollection that samples uniformly + across all non-empty bins. + + Parameters + ---------- + dataset : PathsDatasetCollection + Dataset to load from. + examples_per_epoch : int, optional + Number of examples per epoch. Default is 5000. + batch_size : int, optional + Number of curves per batch. Default is 32. + num_workers : int, optional + Number of worker processes for data loading. Default is 0. + + Returns + ------- + DataLoader + """ + sampler = PathSampler(dataset, examples_per_epoch) if use_sampler else None + return DataLoader( + dataset, + batch_size=batch_size, + collate_fn=collate_curves, + num_workers=num_workers, + pin_memory=True, + sampler=sampler, + ) diff --git a/src/neuron_proofreader/geometric_learning/curve_transformer.py b/src/neuron_proofreader/geometric_learning/curve_transformer.py new file mode 100644 index 0000000..45cc0da --- /dev/null +++ b/src/neuron_proofreader/geometric_learning/curve_transformer.py @@ -0,0 +1,366 @@ +""" +Created on Wed June 10 12:00:00 2026 + +@author: Anna Grim +@email: anna.grim@alleninstitute.org + +... + +""" + +import numpy as np +import torch +import torch.nn as nn + +from neuron_proofreader.utils import util + + +class CurveEncoder(nn.Module): + """ + Transformer encoder that maps a 3D space curve, normalized to the unit + sphere, to a fixed-size latent vector. The curve is tokenized into fixed- + length segments with a fixed learned token for the start point at the + origin and a projected token for the end point. Positional encodings are + sinusoidal over the normalized arc position [0, 1], making the encoder + robust to varying numbers of points and path lengths. + """ + + def __init__( + self, + segment_len=10, + d_token=64, + n_heads=4, + n_layers=4, + d_ff=64, + latent_dim=32, + dropout=0.1, + ): + """ + Parameters + ---------- + segment_len : int + Number of points per segment token. + d_token : int + Dimension of each token. + n_heads : int + Number of attention heads. + n_layers : int + Number of transformer encoder layers. + d_ff : int + Feed-forward hidden dimension. + latent_dim : int + Dimension of the output latent vector. + dropout : float + Dropout probability. + """ + # Call parent class + super().__init__() + + # Instance attributes + self.segment_len = segment_len + self.start_token = nn.Parameter(torch.randn(1, 1, d_token)) + self.end_token_proj = nn.Linear(3, d_token) + self.segment_proj = nn.Linear(segment_len * 3, d_token) + + # Archictecture + encoder_layer = nn.TransformerEncoderLayer( + d_model=d_token, + nhead=n_heads, + dim_feedforward=d_ff, + dropout=dropout, + batch_first=True, + ) + self.transformer = nn.TransformerEncoder( + encoder_layer, num_layers=n_layers + ) + self.to_latent = nn.Sequential( + nn.LayerNorm(d_token), + nn.Linear(d_token, latent_dim), + ) + + def forward(self, offsets, mask=None): + """ + Parameters + ---------- + offsets : torch.Tensor + Shape (B, N, 3), normalized to the unit sphere, with + offsets[:, 0] == [0, 0, 0]. N can vary across calls. + mask : torch.Tensor, optional + Shape (B, N), True where padding (to be ignored). Default is None. + + Returns + ------- + z : torch.Tensor + Latent vector of shape (B, latent_dim). + tokens : torch.Tensor + Per-token encodings of shape (B, n_segments + 2, d_token). + """ + B, N, _ = offsets.shape + n_segments = N // self.segment_len + + # Start and end tokens + start_tok = self.start_token.expand(B, -1, -1) # (B, 1, d_token) + end_tok = self.end_token_proj(offsets[:, -1, :]).unsqueeze( + 1 + ) # (B, 1, d_token) + + # Segment tokens + segments = offsets[:, : n_segments * self.segment_len, :] + segments = segments.reshape(B, n_segments, self.segment_len * 3) + seg_tokens = self.segment_proj(segments) # (B, n_seg, d_token) + + # Concatenate: [start | segments | end] + tokens = torch.cat( + [start_tok, seg_tokens, end_tok], dim=1 + ) # (B, n_seg+2, d_token) + + # Convert point-level mask to token-level mask + token_mask = None + if mask is not None: + seg_mask = mask[:, :: self.segment_len][ + :, :n_segments + ] # (B, n_seg) + token_mask = torch.cat( + [ + torch.zeros(B, 1, dtype=torch.bool, device=mask.device), + seg_mask, + torch.zeros(B, 1, dtype=torch.bool, device=mask.device), + ], + dim=1, + ) # (B, n_seg+2) + + # Sinusoidal positional encoding, zeroed out for padding tokens + pe = sinusoidal_encoding( + tokens.shape[1], tokens.shape[2], tokens.device + ) + if token_mask is not None: + pe = pe * (~token_mask).unsqueeze(-1).float() + tokens = tokens + pe + + tokens = self.transformer(tokens, src_key_padding_mask=token_mask) + + # Mean pool over non-padding tokens only + if token_mask is not None: + valid = (~token_mask).unsqueeze(-1).float() + z = self.to_latent((tokens * valid).sum(dim=1) / valid.sum(dim=1)) + else: + z = self.to_latent(tokens.mean(dim=1)) + + return z, tokens + + +class CurveDecoder(nn.Module): + """ + Transformer decoder that reconstructs a 3D space curve from a latent + vector and the encoder's token representations. Positional queries are + sinusoidally encoded over arc position and biased by the global latent. + The output resolution can differ from the encoder input, allowing + decoding at arbitrary granularity. + """ + + def __init__( + self, + n_points=100, + segment_len=10, + d_token=64, + n_heads=4, + n_layers=4, + d_ff=64, + latent_dim=32, + dropout=0.1, + ): + """ + Parameters + ---------- + n_points : int + Default number of output curve points. + segment_len : int + Number of points per segment token (must match encoder). + d_token : int + Dimension of each token throughout the transformer. + n_heads : int + Number of attention heads. + n_layers : int + Number of transformer decoder layers. + d_ff : int + Feed-forward hidden dimension. + latent_dim : int + Dimension of the input latent vector. + dropout : float + Dropout probability. + """ + super().__init__() + self.segment_len = segment_len + self.n_segments = n_points // segment_len + + # Project latent to d_token to bias the positional queries + self.latent_proj = nn.Linear(latent_dim, d_token) + + decoder_layer = nn.TransformerDecoderLayer( + d_model=d_token, + nhead=n_heads, + dim_feedforward=d_ff, + dropout=dropout, + batch_first=True, + ) + self.transformer = nn.TransformerDecoder( + decoder_layer, num_layers=n_layers + ) + + self.to_points = nn.Sequential( + nn.LayerNorm(d_token), + nn.Linear(d_token, segment_len * 3), + ) + + def forward(self, z, encoder_tokens, encoder_mask=None, n_segments=None): + """ + Parameters + ---------- + z : torch.Tensor + Latent vector of shape (B, latent_dim). + encoder_tokens : torch.Tensor + Per-token encoder outputs of shape (B, n_segments + 2, d_token). + encoder_mask : torch.Tensor, optional + Shape (B, n_segments + 2), True where padding. Passed as + memory_key_padding_mask to cross-attention. Default is None. + n_segments : int, optional + Number of output segments. Inferred from encoder tokens if not + provided. + + Returns + ------- + curve : torch.Tensor + Reconstructed curve of shape (B, n_segments * segment_len, 3). + """ + B = z.shape[0] + d_token = encoder_tokens.shape[2] + n_segments = n_segments or self.n_segments + + # Sinusoidal queries over arc position, biased by global latent + pe = sinusoidal_encoding(n_segments, d_token, encoder_tokens.device) + latent = self.latent_proj(z).unsqueeze(1) # (B, 1, d_token) + queries = pe.expand(B, -1, -1) + latent # (B, n_seg, d_token) + + out = self.transformer( + queries, + encoder_tokens, + memory_key_padding_mask=encoder_mask, + ) # (B, n_seg, d_token) + + segments = self.to_points(out) # (B, n_seg, seg_len*3) + offsets = segments.reshape(B, n_segments * self.segment_len, 3) + return offsets + + +class CurveAutoencoder(nn.Module): + + def __init__( + self, + n_points=100, + segment_len=10, + d_token=64, + n_heads=4, + n_layers=4, + d_ff=64, + latent_dim=32, + dropout=0.1, + ): + # Call parent class + super().__init__() + + # Config + self.config = { + "n_points": n_points, + "segment_len": segment_len, + "d_token": d_token, + "n_heads": n_heads, + "n_layers": n_layers, + "d_ff": d_ff, + "latent_dim": latent_dim, + "dropout": dropout, + } + + # Architecture + self.encoder = CurveEncoder( + segment_len=segment_len, + d_token=d_token, + n_heads=n_heads, + n_layers=n_layers, + d_ff=d_ff, + latent_dim=latent_dim, + dropout=dropout, + ) + self.decoder = CurveDecoder( + n_points=n_points, + segment_len=segment_len, + d_token=d_token, + n_heads=n_heads, + n_layers=n_layers, + d_ff=d_ff, + latent_dim=latent_dim, + dropout=dropout, + ) + + def forward(self, offsets, token_mask): + """ + Parameters + ---------- + offsets : torch.Tensor + Shape (B, N, 3), normalized to the unit sphere, offsets[:, 0] == 0. + + Returns + ------- + reconstruction : torch.Tensor + Shape (B, N, 3). + z : torch.Tensor + Latent vector of shape (B, latent_dim). + """ + z, encoder_tokens = self.encoder(offsets, token_mask) + n_segments = offsets.shape[1] // self.decoder.segment_len + reconstruction = self.decoder(z, encoder_tokens, n_segments=n_segments) + return reconstruction, z + + def encode(self, offsets): + z, _ = self.encoder(offsets) + return z + + # --- Helpers --- + def save_config(self, path): + util.write_json(path, self.config) + + @classmethod + def load(cls, path): + checkpoint = torch.load(path) + model = cls(**checkpoint["config"]) + model.load_state_dict(checkpoint["model_state_dict"]) + return model + + +# --- Helpers --- +def sinusoidal_encoding(n_tokens, d_token, device): + """ + Sinusoidal positional encoding over normalised arc position [0, 1]. + + Parameters + ---------- + n_tokens : int + Number of tokens (segments + 2 endpoint tokens). + d_token : int + Model dimension. + device : torch.device + Device to create the encoding on. + + Returns + ------- + torch.Tensor + Encoding of shape (1, n_tokens, d_token). + """ + position = torch.linspace(0, 1, n_tokens, device=device).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_token, 2, device=device) + * (-np.log(10000.0) / d_token) + ) + pe = torch.zeros(n_tokens, d_token, device=device) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + return pe.unsqueeze(0) diff --git a/src/neuron_proofreader/machine_learning/image_augmentation.py b/src/neuron_proofreader/machine_learning/augmentation.py similarity index 83% rename from src/neuron_proofreader/machine_learning/image_augmentation.py rename to src/neuron_proofreader/machine_learning/augmentation.py index 93b2d72..5cb974a 100644 --- a/src/neuron_proofreader/machine_learning/image_augmentation.py +++ b/src/neuron_proofreader/machine_learning/augmentation.py @@ -14,6 +14,7 @@ import random +# --- Image Augmentation --- class ImageTransforms: """ Class that applies a sequence of transforms to a 3D image and segmentation @@ -42,7 +43,7 @@ def __call__(self, patches): ---------- patches : numpy.ndarray Image with the shape (2, H, W, D), where the first channel is the - input image and second is the segmentation. + raw image and second is a mask. """ for transform in self.transforms: patches = transform(patches) @@ -73,8 +74,8 @@ def __call__(self, patches): Parameters ---------- patches : numpy.ndarray - Image with the shape (2, H, W, D), where "patches[0, ...]" is from - the input image and "patches[1, ...]" is from the segmentation. + Image with the shape (2, H, W, D), where the first channel is the + raw image and second is a mask. """ for axis in self.axes: if random.random() > 0.5: @@ -95,7 +96,7 @@ def __init__(self, angles=(-90, 90), axes=((0, 1), (0, 2), (1, 2))): Parameters ---------- angles : Tuple[int], optional - Maximum angle of rotation. Default is (-45, 45). + Maximum angle of rotation. Default is (-90, 90). axis : Tuple[Tuple[int]], optional Axes to apply rotation. Default is ((0, 1), (0, 2), (1, 2)) """ @@ -109,8 +110,8 @@ def __call__(self, patches): Parameters ---------- patches : numpy.ndarray - Image with the shape (2, H, W, D), where "patches[0, ...]" is from - the input image and "patches[1, ...]" is from the segmentation. + Image with the shape (2, H, W, D), where the first channel is the + raw image and second is a mask. """ for axes in self.axes: if random.random() < 0.5: @@ -120,7 +121,7 @@ def __call__(self, patches): return patches @staticmethod - def rotate3d(img_patch, angle, axes, is_segmentation=False): + def rotate3d(img_patch, angle, axes, is_mask=False): """ Rotates a 3D image patch around the specified axes by a given angle. @@ -132,13 +133,12 @@ def rotate3d(img_patch, angle, axes, is_segmentation=False): Angle (in degrees) by which to rotate the image patch around the specified axes. axes : Tuple[int] - Tuple representing the two axes of rotation. - is_segmentation : bool, optional - Indication of whether the image is a segmentation. Default is - False. + Two axes of rotation. + is_mask : bool, optional + True if the image is a mask. """ - order = 0 if is_segmentation else 3 - multipler = 4 if is_segmentation else 1 + order = 0 if is_mask else 3 + multipler = 4 if is_mask else 1 img_patch = rotate( multipler * img_patch, angle, @@ -174,8 +174,8 @@ def __call__(self, patches): Parameters ---------- patches : numpy.ndarray - Image with the shape (2, H, W, D), where "patches[0, ...]" is from - the input image and "patches[1, ...]" is from the segmentation. + Image with the shape (2, H, W, D), where the first channel is the + raw image and second is a mask. Returns ------- @@ -226,8 +226,8 @@ def __call__(self, patches): Parameters ---------- patches : numpy.ndarray - Image with the shape (2, H, W, D), where the zeroth channel is - from the raw image and first channel is from the segmentation. + Image with the shape (2, H, W, D), where the first channel is the + raw image and second is a mask. """ lo = np.percentile(patches[0], np.random.uniform(*self.p_low)) hi = np.percentile(patches[0], np.random.uniform(*self.p_high)) @@ -260,10 +260,13 @@ def __call__(self, patches): Parameters ---------- patches : numpy.ndarray - Image with the shape (2, H, W, D), where "patches[0, ...]" is from - the input image and "patches[1, ...]" is from the segmentation. + Image with the shape (2, H, W, D), where the first channel is the + raw image and second is a mask. """ std = self.max_std * random.random() patches[0] += np.random.uniform(-std, std, patches[0].shape) patches[0] = np.clip(patches[0], 0, 1) return patches + + +# --- 3D Space Curve Augmentation --- diff --git a/src/neuron_proofreader/machine_learning/gnn_models.py b/src/neuron_proofreader/machine_learning/gnn_models.py index b5add00..95f1387 100644 --- a/src/neuron_proofreader/machine_learning/gnn_models.py +++ b/src/neuron_proofreader/machine_learning/gnn_models.py @@ -39,7 +39,7 @@ def __init__( self, patch_shape, disable_msg_passing=False, - heads=2, + heads=4, hidden_dim=128, n_layers=2, ): diff --git a/src/neuron_proofreader/machine_learning/image_dataloader.py b/src/neuron_proofreader/machine_learning/image_dataloader.py index a0c1834..554be34 100644 --- a/src/neuron_proofreader/machine_learning/image_dataloader.py +++ b/src/neuron_proofreader/machine_learning/image_dataloader.py @@ -14,7 +14,7 @@ import tensorstore as ts from neuron_proofreader.configs import ImageConfig -from neuron_proofreader.machine_learning.image_augmentation import ( +from neuron_proofreader.machine_learning.augmentation import ( ImageTransforms, ) from neuron_proofreader.utils import geometry_util, img_util, util @@ -234,7 +234,7 @@ def create_mask(self, center, shape, node): # Annotate mask mask = np.zeros(shape) - #self.annotate_foreground(mask, nodes, offset, fill_val=0.5) TEMP + # self.annotate_foreground(mask, nodes, offset, fill_val=0.5) TEMP self.annotate_fragment(mask, subgraph, offset, fill_val=1) return mask diff --git a/src/neuron_proofreader/machine_learning/train.py b/src/neuron_proofreader/machine_learning/train.py index e0cca48..9c65b4a 100644 --- a/src/neuron_proofreader/machine_learning/train.py +++ b/src/neuron_proofreader/machine_learning/train.py @@ -66,6 +66,7 @@ def __init__( model_name, output_dir, device="cuda", + exp_name=None, lr=1e-3, max_epochs=200, min_recall=0, @@ -94,7 +95,8 @@ def __init__( Indication of whether to save MIPs of mistakes. Default is False. """ # Set experiment name - exp_name = "session-" + datetime.today().strftime("%Y%m%d_%H%M") + if exp_name is None: + exp_name = "session-" + datetime.today().strftime("%Y%m%d_%H%M") log_dir = os.path.join(output_dir, exp_name) util.mkdir(log_dir) diff --git a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py index d9b336a..3e97fd7 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datamodules.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datamodules.py @@ -132,7 +132,7 @@ def set_merge_site_info(self): # Store fragment IDs corresponding to merge sites for xyz in self.merge_sites["xyz"]: _, ii = self.kdtree.query(xyz) - #self.ignore_fragments.add(self.node_component_id[ii]) + # self.ignore_fragments.add(self.node_component_id[ii]) def set_giant_components(self): for nodes in map(list, nx.connected_components(self.graph)): @@ -176,7 +176,9 @@ def get_random_nonmerge_site(self): # Try again n_attempts += 1 if n_attempts > 100: - print(f"Failed to find valid random nonmerge site for {self.brain_id}!") + print( + f"Failed to find valid random nonmerge site for {self.brain_id}!" + ) return util.sample_once(self.nodes) # --- Helpers --- @@ -324,21 +326,25 @@ def save_val_summary(self, output_dir): # Merge sites for _, row in brain_dataset.merge_sites.iterrows(): - rows.append({ - "brain_id": brain_id, - "swc_name": row["filename"], - "xyz": row["xyz"], - "label": "merge", - }) + rows.append( + { + "brain_id": brain_id, + "swc_name": row["filename"], + "xyz": row["xyz"], + "label": "merge", + } + ) # Nonmerge sites for _, row in brain_dataset.nonmerge_sites.iterrows(): - rows.append({ - "brain_id": brain_id, - "swc_name": row["filename"], - "xyz": row["xyz"], - "label": "nonmerge", - }) + rows.append( + { + "brain_id": brain_id, + "swc_name": row["filename"], + "xyz": row["xyz"], + "label": "nonmerge", + } + ) df = pd.DataFrame(rows) df.to_csv(os.path.join(output_dir, "val_summary.csv"), index=False) @@ -439,7 +445,7 @@ def __iter__(self): # Split into batches upfront batch_idx_groups = [ - idxs[start: min(start + self.batch_size, len(idxs))] + idxs[start : min(start + self.batch_size, len(idxs))] for start in range(0, len(idxs), self.batch_size) ] @@ -631,10 +637,12 @@ def create_dataset_collection( img_path = os.path.join(img_prefixes[brain_id], "0") segmentation_id = get_segmentation_id(sites_root_path, brain_id) sites_path = os.path.join(sites_root_path, brain_id, segmentation_id) - swcs_path = os.path.join(swcs_root_path, brain_id, segmentation_id, "fragments") - #util.get_google_swcs_prefix( + swcs_path = os.path.join( + swcs_root_path, brain_id, segmentation_id, "fragments" + ) + # util.get_google_swcs_prefix( # swcs_root_path, brain_id, segmentation_id - #) + # ) # Add dataset print(f" \nBrain ID [{i}/{len(brain_ids)}]: {brain_id}") diff --git a/src/neuron_proofreader/proofreading_pipeline.py b/src/neuron_proofreader/proofreading_pipeline.py new file mode 100644 index 0000000..0cb8f8a --- /dev/null +++ b/src/neuron_proofreader/proofreading_pipeline.py @@ -0,0 +1,213 @@ +""" +Created on Fri June 13 16:00:00 2026 + +@author: Anna Grim +@email: anna.grim@alleninstitute.org + +Code for running full neuron proofreading pipeline, including both split and +merge detection and correction. + +""" + +from time import time + +import numpy as np +import os + +from neuron_proofreader.proposal_graph import ProposalGraph +from neuron_proofreader.split_proofreading.split_inference import ( + SplitProofreader, +) +from neuron_proofreader.utils import geometry_util, swc_util, util + + +class ProofreadPipeline: + + def __init__( + self, + swcs_path, + graph_config, + img_config, + output_dir, + device="cuda", + log_preamble="", + soma_centroids=list(), + ): + """ + Initializes an object that executes the full split proofreading + pipeline. + + Parameters + ---------- + swcs_path : str + Path to SWC files to be loaded into graph. + graph_config : GraphConfig + Config object that contains parameters for building graph. + img_config : ImageConfig + Config object that contains parameters for processing images. + output_dir : str + Directory where the results of the inference will be saved. + log_preamble : str, optional + String to be added to the beginning of log. Default is an empty + string. + soma_centroids : List[Tuple[float]], optional + Physical coordinates of soma centroids. Default is an empty list. + """ + # Instance attributes + self.device = device + self.img_config = img_config + self.output_dir = output_dir + self.step_cnt = 0 + + # Logger + util.mkdir(self.output_dir) + log_path = os.path.join(self.output_dir, "summary.txt") + self.log_handle = open(log_path, "a") + self.log(log_preamble) + + # Load data + self.load_graph(graph_config, swcs_path, soma_centroids) + + def load_graph(self, config, swcs_path, soma_centroids): + """ + Loads a graph from the given fragments. + + Parameters + ---------- + swcs_path : str + Path to SWC files to be loaded into graph. + config : GraphConfig + Configuration object that contains parameters for building graph. + """ + # Load data + t0 = time() + self.step_cnt += 1 + self.log(f"Step {self.step_cnt}: Build Graph") + self.graph = ProposalGraph( + anisotropy=config.anisotropy, + min_cable_length=config.min_cable_length, + node_spacing=config.node_spacing, + verbose=config.verbose, + ) + self.graph.load(swcs_path) + self.graph.load_somas(soma_centroids) + + # Remove doubled fragments + if config.remove_doubles: + geometry_util.remove_doubles(self.graph, 200) + + # Save original graph state + self.save_graph("original_swcs") + self.log("\nInitial Graph...") + self.log(self.graph.__repr__()) + + # Report runtime + elapsed, unit = util.time_writer(time() - t0) + self.log(f"Module Runtime: {elapsed:.2f} {unit}\n") + + # --- Split Proofreading --- + def split_proofreading( + self, + model, + proposals_config, + batch_size=32, + dt=0.05, + min_threshold=0.8, + removal_threshold=0.3, + save_result=True, + ): + # Create proofreader + proofreader = SplitProofreader( + self.graph, + model, + self.img_config, + self.output_dir, + batch_size=batch_size, + device=self.device, + log_handle=self.log_handle, + ) + + # Run inference + self.step_cnt += 1 + self.log(f"\nStep {self.step_cnt}: Split Proofreading") + proofreader( + proposals_config, + dt=dt, + min_threshold=min_threshold, + removal_threshold=removal_threshold, + ) + + # Save final graph + if save_result: + self.log("Final Graph...") + self.log(self.graph.__repr__()) + self.reconfigure_node_radius() + self.save_graph("corrected_swcs") + + def connect_soma_fragments(self, max_dist=25): + self.step_cnt += 1 + self.log(f"\nStep {self.step_cnt}: Connect Soma Fragments with dist={max_dist}") + summary = self.graph.connect_soma_fragments(max_dist=max_dist) + self.log(summary) + + # --- Merge Proofreading --- + def merge_proofreading(self, mode): + # Report step + self.step_cnt += 1 + self.log( + f"\nStep {self.step_cnt}: Merge Proofreading with mode={mode}" + ) + + # Detect merges + if mode == "heuristic": + merge_sites, summary = self.graph.remove_high_risk_merges() + elif mode == "connected_somas": + merge_sites, summary = self.graph.remove_soma_merges() + + # Report results + self.log(summary) + + # Save sites + color = "# COLOR 1.0 0.0 0.0" + zip_path = os.path.join(self.output_dir, f"{mode}_merge_sites.zip") + swc_util.write_points( + zip_path, merge_sites, color=color, prefix="merge_site", radius=10 + ) + + # --- Helpers --- + def log(self, txt): + """ + Logs and prints the given text. + + Parameters + ---------- + txt : str + Text to be logged and printed. + """ + print(txt) + self.log_handle.write(txt) + self.log_handle.write("\n") + + def reconfigure_node_radius(self): + n_nodes = len(self.graph.node_radius) + self.graph.node_radius = np.ones((n_nodes), dtype=np.float16) + for i, j in self.graph.accepts: + self.graph.node_radius[i] = 6 + self.graph.node_radius[j] = 6 + + def save_fragment_ids(self): + path = f"{self.output_dir}/segment_ids.txt" + segment_ids = list(self.graph.component_id_to_swc_id.values()) + util.write_list(path, segment_ids) + + def save_graph(self, dirname): + # Save graph across set of ZIPs + temp_dir = os.path.join(self.output_dir, "temp") + self.graph.to_zipped_swcs_multithreaded(temp_dir) + + # Combine ZIPs into single ZIP + zip_paths = util.list_paths(temp_dir, extension=".zip") + final_zip_path = os.path.join(self.output_dir, dirname, "swcs.zip") + util.mkdir(os.path.join(self.output_dir, dirname)) + util.combine_zips(zip_paths, final_zip_path) + util.rmdir(temp_dir) diff --git a/src/neuron_proofreader/proposal_graph.py b/src/neuron_proofreader/proposal_graph.py index 7a9616e..f4bf813 100644 --- a/src/neuron_proofreader/proposal_graph.py +++ b/src/neuron_proofreader/proposal_graph.py @@ -22,7 +22,7 @@ ProposalGenerator, trim_proposal_endpoints, ) -from neuron_proofreader.utils import geometry_util, graph_util +from neuron_proofreader.utils import geometry_util class ProposalGraph(SkeletonGraph): @@ -37,7 +37,6 @@ def __init__( self, anisotropy=(1.0, 1.0, 1.0), gt_path=None, - max_proposals_per_leaf=3, min_cable_length=0, node_spacing=1, prune_depth=20.0, @@ -63,36 +62,22 @@ def __init__( graph. Default is True. """ # Call parent class - super().__init__() - - # Instance attributes - Graph - self.anisotropy = anisotropy - self.component_id_to_swc_id = dict() - self.gt_path = gt_path - self.soma_component_ids = set() - self.verbose = verbose + super().__init__( + anisotropy=anisotropy, + min_cable_length=min_cable_length, + node_spacing=node_spacing, + prune_depth=prune_depth, + verbose=verbose, + ) # Instance attributes - Proposals self.accepts = set() self.gt_accepts = set() + self.gt_path = gt_path self.merged_ids = set() self.n_merges_blocked = 0 self.n_proposals_blocked = 0 - self.reset_proposals() - self.proposal_generator = ProposalGenerator( - self, - max_proposals_per_leaf=max_proposals_per_leaf, - ) - - # Graph Loader - self.graph_loader = graph_util.GraphLoader( - anisotropy=anisotropy, - min_cable_length=min_cable_length, - node_spacing=node_spacing, - prune_depth=prune_depth, - verbose=verbose, - ) # --- Update Structure --- def relabel_nodes(self): @@ -135,6 +120,7 @@ def generate_proposals( self, search_radius, allow_nonleaf_proposals=False, + max_proposals_per_leaf=3, min_size_with_proposals=0, ): """ @@ -153,11 +139,13 @@ def generate_proposals( """ # Proposal generation assert len(self.kdtree.data) == self.number_of_nodes() - proposals = self.proposal_generator( - search_radius, + proposal_generator = ProposalGenerator( + self, allow_nonleaf_proposals=allow_nonleaf_proposals, + max_proposals_per_leaf=max_proposals_per_leaf, min_size_with_proposals=min_size_with_proposals, ) + proposals = proposal_generator(search_radius) self.search_radius = search_radius self.store_proposals(proposals) @@ -328,7 +316,7 @@ def proposal_length(self, proposal): return self.dist(*tuple(proposal)) def proposal_midpoint(self, proposal): - return geometry_util.midpoint(*self.proposal_xyz(proposal)) + return self.midpoint(*proposal) def proposal_radius(self, proposal): i, j = proposal diff --git a/src/neuron_proofreader/skeleton_graph.py b/src/neuron_proofreader/skeleton_graph.py index ca47684..c53beda 100644 --- a/src/neuron_proofreader/skeleton_graph.py +++ b/src/neuron_proofreader/skeleton_graph.py @@ -85,6 +85,7 @@ def __init__( self.node_spacing = node_spacing self.soma_centroids = list() self.soma_component_ids = list() + self.verbose = verbose # Graph Loader anisotropy = anisotropy if use_anisotropy else (1.0, 1.0, 1.0) @@ -295,15 +296,7 @@ def connect_soma_fragments(self, max_dist=25): def remove_soma_merges(self): """ - Breaks a fragment that intersects with multiple somas so that nodes - closest to soma locations are disconnected. - - Parameters - ---------- - swc_dict : dict - Contents of an SWC file. - somas_xyz : List[Tuple[float]] - Physical coordinates representing soma locations. + Removes branching points along paths that connect multiple somas. """ # Check whether to exit if len(self.soma_centroids) == 0: @@ -316,6 +309,7 @@ def remove_soma_merges(self): component_id_to_soma_nodes[self.node_component_id[i]].add(i) # Check for connected components with multiple soma nodes + merge_sites = list() n_soma_merges = 0 for component_id, soma_nodes in component_id_to_soma_nodes.items(): if len(soma_nodes) > 1 and len(soma_nodes) < 20: @@ -325,16 +319,14 @@ def remove_soma_merges(self): dist, _ = somas_kdtree.query(self.node_xyz[i]) if self.degree[i] > 2 and dist > 25: rm_nodes.add(i) + merge_sites.append(self.node_xyz[i]) self.remove_nearby_nodes(rm_nodes) # Finish self.remove_small_components(relabel_nodes=False) self.relabel_nodes() - results = [ - f"# Soma Fragments: {len(self.soma_centroids)}", - f"# Soma Merges: {n_soma_merges}", - ] - return "\n".join(results) + summary = f"# Soma Merges: {n_soma_merges}" + return merge_sites, summary def remove_nearby_nodes(self, roots, max_dist=5.0): """ @@ -853,6 +845,43 @@ def irreducible_nodes(self): """ return {i for i in map(int, self.nodes) if self.degree[i] != 2} + def irreducible_paths(self): + """ + Extracts non-branching paths between irreducible nodes (degree 1 or >= 3). + + Returns + ------- + paths : List[numpy.ndarray] + Each entry is an ordered list of node IDs forming a path between two + irreducible nodes, inclusive of both endpoints. + """ + # Initializations + irreducible = {n for n in self.nodes if self.degree(n) != 2} + paths = [] + visited_edges = set() + + # Search + for source in irreducible: + for nb in self.neighbors(source): + edge = frozenset((source, nb)) + if edge in visited_edges: + continue + visited_edges.add(edge) + + # Walk along degree-2 chain until next irreducible node + path = [source, nb] + prev, curr = source, nb + while curr not in irreducible: + nxt = next(n for n in self.neighbors(curr) if n != prev) + edge = frozenset((curr, nxt)) + visited_edges.add(edge) + path.append(nxt) + prev, curr = curr, nxt + + paths.append(np.array(path, dtype=int)) + + return paths + def leaf_nodes(self): """ Gets all leaf nodes in the graph. @@ -864,6 +893,9 @@ def leaf_nodes(self): """ return [i for i in self.nodes if self.degree[i] == 1] + def midpoint(self, i, j): + return geometry_util.midpoint(self.node_xyz[i], self.node_xyz[j]) + def node_local_voxel(self, node, offset): """ Computes the local voxel coordinate of the given node within the image @@ -1053,9 +1085,7 @@ def path_length(self, path): """ if len(path) > 1: diffs = self.node_xyz[path[1:]] - self.node_xyz[path[:-1]] - return np.sqrt(np.sum(diffs**2)) - else: - return 0 + return np.linalg.norm(diffs**2, axis=1).sum() def path_thru_node(self, i, max_depth=np.inf): if self.degree[i] == 0: @@ -1089,8 +1119,8 @@ def remove_high_risk_merges(self, max_dist=7): somas_kdtree = KDTree(self.node_xyz[soma_nodes]) # Iterate over branching nodes - cnt = 0 rm_nodes = set() + merge_sites = list() while branching_nodes: # Set root of search root = branching_nodes.pop() @@ -1110,6 +1140,7 @@ def remove_high_risk_merges(self, max_dist=7): i, dist_i = queue.pop() if self.degree[i] > 2 and i != root: hit_branching_nodes.add(i) + merge_sites.append(self.midpoint(root, i)) # Update queue for j in self.neighbors(i): @@ -1122,9 +1153,11 @@ def remove_high_risk_merges(self, max_dist=7): if hit_branching_nodes or self.degree(root) > 3: rm_nodes = rm_nodes.union(visited) branching_nodes -= hit_branching_nodes - cnt += 1 + + # Update graph self.remove_nodes(rm_nodes) - return f"# High Risk Merges: {cnt}" + summary = f"# High Risk Merges: {len(merge_sites)}" + return merge_sites, summary def rooted_subgraph(self, root, radius): """ @@ -1184,38 +1217,19 @@ def soma_nodes(self): soma_nodes.append(i) return soma_nodes - def summary(self, prefix=""): - """ - Generate a human-readable summary of the graph. - - Parameters - ---------- - prefix : str, optional - Optional string to prepend to the summary title. - - Returns - ------- - summary : str - Formatted multi-line string containing: - - Graph Name - - Number of connected components - - Number of nodes - - Number of edges - - Memory consumption (in GBs) - """ - # Compute values + def __repr__(self): n_components = format(nx.number_connected_components(self), ",") n_nodes = format(self.number_of_nodes(), ",") n_edges = format(self.number_of_edges(), ",") memory = util.get_memory_usage() - - # Compile results - summary = [f"{prefix} Graph"] - summary.append(f"# Connected Components: {n_components}") - summary.append(f"# Nodes: {n_nodes}") - summary.append(f"# Edges: {n_edges}") - summary.append(f"Memory Consumption: {memory:.2f} GBs") - return "\n".join(summary) + return ( + f" SkeletonGraph(\n" + f" num_connected_components={n_components},\n" + f" num_nodes={n_nodes},\n" + f" num_edges={n_edges},\n" + f" memory={memory:.2f} GBs,\n" + f" )" + ) def swc_ids(self): """ diff --git a/src/neuron_proofreader/split_proofreading/proposal_generation.py b/src/neuron_proofreader/split_proofreading/proposal_generation.py index 729597c..9a70cc6 100644 --- a/src/neuron_proofreader/split_proofreading/proposal_generation.py +++ b/src/neuron_proofreader/split_proofreading/proposal_generation.py @@ -24,8 +24,10 @@ class ProposalGenerator: def __init__( self, graph, + allow_nonleaf_proposals=False, max_attempts=2, max_proposals_per_leaf=3, + min_size_with_proposals=0, search_scaling_factor=1.5, ): """ @@ -35,29 +37,31 @@ def __init__( ---------- graph : ProposalGraph Graph that proposals will be generated for. + allow_nonleaf_proposals : bool, optional + Indication of whether to generate proposals between leaf and nodes + with degree 2. Default is False. max_attempts : int, optional Number of attempts made to generate proposals from a node with increasing search radii. Default is 2. max_proposals_per_leaf : bool, optional Maximum number of proposals generated at each leaf. Default is 3. + min_size_with_proposals : float, optional + Minimum cable path length required for fragments that proposals + are generated from. Default is 0. search_scaling_factor : 1.5, optional Scaling actor used to enlarge search radius for each search. Default is 2. """ # Instance attributes - self.allow_nonleaf_proposals = None + self.allow_nonleaf_proposals = allow_nonleaf_proposals self.graph = graph self.kdtree = None self.max_attempts = max_attempts self.max_proposals_per_leaf = max_proposals_per_leaf + self.min_size_with_proposals = min_size_with_proposals self.search_scaling_factor = search_scaling_factor - def __call__( - self, - initial_radius, - allow_nonleaf_proposals=False, - min_size_with_proposals=0, - ): + def __call__(self, initial_radius): """ Generates edge proposals between fragments within the given search radius. @@ -67,15 +71,8 @@ def __call__( initial_radius : float Initial search radius used to generate proposals between endpoints of proposal. - allow_nonleaf_proposals : bool, optional - Indication of whether to generate proposals between leaf and nodes - with degree 2. Default is False. - min_size_with_proposals : float, optional - Minimum cable path length required for fragments that proposals - are generated from. Default is 0. """ # Initializations - self.allow_nonleaf_proposals = allow_nonleaf_proposals self.set_kdtree() iterator = self.graph.leaf_nodes() if self.graph.verbose: @@ -87,9 +84,9 @@ def __call__( for leaf in iterator: # Check if fragment satisfies size requirement length = self.graph.cable_length( - max_depth=min_size_with_proposals, root=leaf + max_depth=self.min_size_with_proposals, root=leaf ) - if length < min_size_with_proposals: + if length < self.min_size_with_proposals: continue # Generate proposals diff --git a/src/neuron_proofreader/split_proofreading/split_datasets.py b/src/neuron_proofreader/split_proofreading/split_datasets.py index ad81b22..635ef53 100644 --- a/src/neuron_proofreader/split_proofreading/split_datasets.py +++ b/src/neuron_proofreader/split_proofreading/split_datasets.py @@ -17,8 +17,6 @@ import os import pandas as pd -from neuron_proofreader.proposal_graph import ProposalGraph -from neuron_proofreader.machine_learning.augmentation import ImageTransforms from neuron_proofreader.machine_learning.subgraph_sampler import ( SubgraphSampler, ) @@ -26,10 +24,10 @@ FeaturePipeline, HeteroGraphData, ) -from neuron_proofreader.utils import geometry_util, util +from neuron_proofreader.utils import util -# --- Single Brain Dataset --- +# --- Datasets --- class FragmentsDataset(IterableDataset): """ A dataset object that contains a graph built from fragments corresponding @@ -39,13 +37,11 @@ class FragmentsDataset(IterableDataset): def __init__( self, - fragments_path, - img_path, - config, + fragments_graph, + img_config, + batch_size=32, gt_path=None, - metadata_path=None, prefetch=4, - soma_centroids=set(), ): """ Instantiates a FragmentsDataset object. @@ -56,65 +52,26 @@ def __init__( Path to predicted SWC files to be loaded. img_path : str Path to the raw image associated with the fragments. - config : Config - Configuration object containing parameters and settings. + graph_config : GraphConfig + ... gt_path : str, optional Path to ground-truth SWC files to be loaded. Default is None. - metadata_path : str, optional - Patch to JSON file containing metadata on block that fragments - were extracted from. Default is None. - soma_centroids : List[Tuple[int]], optional - Phyiscal coordinates of soma centroids. Default is an empty list. """ # Instance attributes - self.config = config + self.batch_size = batch_size + self.graph = fragments_graph self.gt_path = gt_path self.prefetch = prefetch - self.transform = ImageTransforms() if config.ml.transform else False - - # Build graph - self._load_graph(fragments_path, metadata_path) - self.graph.load_somas(soma_centroids) + self.transform = img_config.transform # Feature extractor self.feature_extractor = FeaturePipeline( self.graph, - img_path, - brightness_clip=self.config.ml.brightness_clip, - patch_shape=self.config.ml.patch_shape, + img_config.img_path, + brightness_clip=img_config.brightness_clip, + patch_shape=img_config.patch_shape, ) - def _load_graph(self, fragments_path, metadata_path=None): - """ - Loads a graph by reading and processing SWC files specified by the - given path. - - Parameters - ---------- - fragments_path : str - Path to SWC files to be loaded. - metadata_path : str, optional - Patch to JSON file containing metadata on block that fragments - were extracted from. Default is None. - """ - # Build graph - self.graph = ProposalGraph( - anisotropy=self.config.graph.anisotropy, - gt_path=self.gt_path, - min_cable_length=self.config.graph.min_cable_length, - node_spacing=self.config.graph.node_spacing, - prune_depth=self.config.graph.prune_depth, - verbose=self.config.graph.verbose, - ) - self.graph.load(fragments_path) - - # Post process fragments - if metadata_path: - self.graph.clip_to_bbox(metadata_path) - - if self.config.graph.remove_doubles: - geometry_util.remove_doubles(self.graph, 200) - # --- Get Data --- def __iter__(self): """ @@ -137,18 +94,17 @@ def __getattr__(self, name): def get_sampler(self): """ - Gets a subgraph sampler that is used to iterate over dataset. + Gets a subgraph sampler used to iterate over dataset. Returns ------- sampler : SubgraphSampler Subgraph sampler that is used to iterate over dataset. """ - batch_size = self.config.ml.batch_size - return iter(SubgraphSampler(self.graph, max_proposals=batch_size)) + sampler = SubgraphSampler(self.graph, max_proposals=self.batch_size) + return iter(sampler) -# --- Multi-Brain Dataset --- class FragmentsDatasetCollection(IterableDataset): """ A dataset class for storing a set of FragmentDataset objects corresponding diff --git a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py index 7976898..7330221 100644 --- a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py +++ b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py @@ -18,6 +18,9 @@ import numpy as np import torch +from neuron_proofreader.machine_learning.image_dataloader import ( + TensorStoreImage, +) from neuron_proofreader.utils import geometry_util, graph_util, img_util, util from neuron_proofreader.utils.ml_util import TensorDict @@ -242,15 +245,12 @@ def __init__( Number of voxels to be added in each dimension from start and end point of proposal for image patch extraction. Default is 40. """ - # Instance attributes self.brightness_clip = brightness_clip self.graph = graph + self.img = TensorStoreImage(img_path) self.patch_shape = patch_shape self.padding = padding - # Image reader - self.img = img_util.TensorStoreImage(img_path) - def __call__(self, subgraph, features): """ Extracts image patches and profiles for each proposal in the graph. diff --git a/src/neuron_proofreader/split_proofreading/split_inference.py b/src/neuron_proofreader/split_proofreading/split_inference.py index ed84f02..e9055d3 100644 --- a/src/neuron_proofreader/split_proofreading/split_inference.py +++ b/src/neuron_proofreader/split_proofreading/split_inference.py @@ -7,24 +7,19 @@ Code that executes the full split correction pipeline. Inference Pipeline: - 1. Graph Construction - Build graph from neuron fragments. - - 2. Proposal Generation + 1. Proposal Generation Generate proposals for potential connections between fragments. - 3. Proposal Classification + 2. Proposal Classification a. Feature Generation Extract features from proposals and graph for a machine learning model. b. Predict with Graph Neural Network (GNN) Run a GNN to classify proposals as accept/reject based on the learned features. - c. Merge Accepted Proposals - Add accepted proposals to the graph as edges. -Note: Steps 2 and 3 of the inference pipeline can be iterated in a loop that - repeats multiple times by calling the pipeline in a loop + 3. Merge Accepted Proposals + Add accepted proposals as edges to the graph. """ @@ -32,7 +27,6 @@ from tqdm import tqdm import networkx as nx -import numpy as np import pandas as pd import os import torch @@ -43,24 +37,20 @@ from neuron_proofreader.utils import ml_util, util -class InferencePipeline: +class SplitProofreader: """ - Class that executes the full split proofreader inference pipeline by - performing the following steps: - (1) Graph Construction - (2) Proposal Generation - (3) Proposal Classification. + Class that executes the full split proofreader inference pipeline. """ def __init__( self, - fragments_path, - img_path, - output_dir, + graph, model, - config, - log_preamble="", - soma_centroids=list(), + img_config, + output_dir, + batch_size=32, + device="cuda", + log_handle=None, ): """ Initializes an object that executes the full split correction @@ -68,97 +58,46 @@ def __init__( Parameters ---------- - fragments_path : str - Path to SWC files to be loaded into graph. - img_path : str - Path to whole-brain image corresponding to the given fragments. - output_dir : str - Directory where the results of the inference will be saved. - config : Config - Configuration object containing parameters and settings required - for the inference pipeline. - log_preamble : str, optional - String to be added to the beginning of log. Default is an empty - string. - soma_centroids : List[Tuple[float]], optional - Physcial coordinates of soma centroids. Default is an empty list. + ... """ # Instance attributes - self.accepted_proposals = list() - self.config = config - self.img_path = img_path - self.model = model.to(config.ml.device) - self.output_dir = output_dir - self.soma_centroids = soma_centroids - - # Logger - util.mkdir(self.output_dir) - log_path = os.path.join(self.output_dir, "runtimes.txt") - self.log_handle = open(log_path, "a") - self.log(log_preamble) - - # Load data - self._load_data(fragments_path, img_path) - - def _load_data(self, fragments_path, img_path): - """ - Builds a graph from the given fragments. - - Parameters - ---------- - fragments_path : str - Path to SWC files to be loaded into graph. - img_path : str - Path to whole-brain image corresponding to the given fragments. - """ - # Load data - t0 = time() - self.log("Step 1: Build Graph") self.dataset = FragmentsDataset( - fragments_path, - img_path, - self.config, - soma_centroids=self.soma_centroids, + graph, + img_config, + batch_size=batch_size, ) - self.log(self.dataset.summary(prefix="\nInitial")) - self.save_fragment_ids() - self.save_graph("original_swcs") - - # Postprocess fragments with somas - self.log(self.dataset.remove_soma_merges()) - self.log(self.dataset.connect_soma_fragments()) - - # Break high risk merges (if applicable) - if self.config.graph.remove_high_risk_merges: - self.log(self.dataset.remove_high_risk_merges()) - self.log(self.dataset.summary(prefix="\nPre-Corrected")) - self.save_graph("precorrected_swcs") + self.device = device + self.model = model + self.output_dir = output_dir - # Report runtime - elapsed, unit = util.time_writer(time() - t0) - self.log(f"Module Runtime: {elapsed:.2f} {unit}\n") + # Logger + log_path = os.path.join(self.output_dir, "summary.txt") + self.log_handle = log_handle or open(log_path, "a") - # --- Pipelines --- def __call__( - self, search_radius, dt=0.1, min_threshold=0.75, removal_threshold=0.3 + self, + proposals_config, + dt=0.1, + min_threshold=0.8, + removal_threshold=0.3, ): """ Executes the full inference pipeline. Parameters ---------- - search_radius : float - Search radius (in microns) used to generate proposals. + proposals_config : ProposalsConfig + Config object with settings for proposal generation. dt : float, optional Increment that acceptance threshold is lowered by. Default is 0.1. min_threshold : float, optional - Minimum threshold for accepting proposals. Default is 0.75. + Minimum threshold for accepting proposals. Default is 0.8. removal_threshold : float, optional Proposals with model predictions less than this value are removed. Default is 0.3. """ # Generate proposals - self.generate_proposals(search_radius) + self.generate_proposals(proposals_config) total_proposals = self.dataset.n_proposals() # Run inference @@ -173,80 +112,63 @@ def __call__( # Generate predictions while self.dataset.proposals: - # Generate predictons + # Generate proposal predictons cnt += 1 self.log( - f"\nThreshold={new_threshold} w/ only_leaf2leaf={only_leaf2leaf}" + f"\n--- Threshold={new_threshold} w/ only_leaf2leaf={only_leaf2leaf} ---" ) preds = self.predict_proposals( suffix=f"{name}_round={cnt}_threshold={new_threshold}" ) - # Merge accetped proposals + # Merge accepted proposals cur_threshold = new_threshold self.merge_with_threshold_schedule( - preds, cur_threshold, only_leaf2leaf=only_leaf2leaf + preds, cur_threshold, dt=dt, only_leaf2leaf=only_leaf2leaf ) - self.filter_proposals(preds, removal_threshold) - # Update threshold + # Remove rejected proposals + self.remove_proposals(preds, removal_threshold) + + # Update acceptance threshold new_threshold = max(cur_threshold - dt, min_threshold) if cur_threshold == new_threshold: break # Report results t, unit = util.time_writer(time() - t0) - p_accepts = len(self.dataset.accepts) / total_proposals - self.log(self.dataset.summary(prefix="\nFinal")) - self.log(f"Overall Acceptance Rate: {p_accepts:.2f}") + p_accepts = 100 * len(self.dataset.accepts) / total_proposals + self.log(f"Overall Accepted: {p_accepts:.2f}%") self.log(f"Total Runtime: {t:.2f} {unit}\n") - self.save_results() + self.save_connections() # --- Core Routines --- - def filter_proposals(self, preds, threshold): - # Remove based on model predictions and mergeability - cnt = 0 - for proposal, pred in preds.items(): - is_valid = self.dataset.is_mergeable(*proposal) - if pred < threshold or not is_valid: - self.dataset.remove_proposal(proposal) - cnt += 1 - - # Sanity check - for proposal in self.dataset.list_proposals(): - i, j = proposal - if self.dataset.degree[i] > 2 or self.dataset.degree[j] > 2: - self.dataset.remove_proposal(proposal) - cnt += 1 - - self.log("Filter Proposals") - self.log(f"# Proposals Removed: {cnt}") - self.log(f"# Proposals Remaining: {self.dataset.n_proposals()}\n") - - def generate_proposals(self, search_radius): + def generate_proposals(self, proposals_config): """ Generates proposals for the fragments graph based on the specified configuration. Parameters ---------- - search_radius : float - Search radius (in microns) used to generate proposals. + proposals_config : ProposalsConfig + Config object with settings for proposal generation. """ # Main t0 = time() - self.log("\nStep 2: Generate Proposals") - self.log(f"Search Radius: {search_radius}") + self.log("Generate Proposals...") self.dataset.generate_proposals( - search_radius, - allow_nonleaf_proposals=self.config.graph.allow_nonleaf_proposals, + proposals_config.search_radius, + allow_nonleaf_proposals=proposals_config.allow_nonleaf_proposals, + max_proposals_per_leaf=proposals_config.max_proposals_per_leaf, + min_size_with_proposals=proposals_config.min_size_with_proposals, ) + # Report results n_proposals = format(self.dataset.n_proposals(), ",") n_proposals_blocked = self.dataset.n_proposals_blocked - - # Report results t, unit = util.time_writer(time() - t0) + + self.log(f"Search Radius: {proposals_config.search_radius}") self.log(f"# Proposals: {n_proposals}") self.log(f"# Proposals Blocked: {n_proposals_blocked}") self.log(f"Module Runtime: {t:.2f} {unit}\n") @@ -273,7 +195,6 @@ def merge_with_threshold_schedule( """ # Initializations t0 = time() - self.log("\nStep 3: Run Inference") n_proposals = self.dataset.n_proposals() n_accepts = 0 @@ -293,6 +214,7 @@ def merge_with_threshold_schedule( # Report results t, unit = util.time_writer(time() - t0) + self.log("Inference...") self.log(f"# Merges Blocked: {self.dataset.n_merges_blocked}") self.log(f"# Accepted: {format(n_accepts, ',')}") self.log(f"% Accepted: {100 * n_accepts / n_proposals:.2f}") @@ -316,13 +238,13 @@ def predict_proposals(self, suffix=""): pbar.update(data.n_proposals()) # Save results - self.save_proposal_results(preds, suffix=suffix) + self.save_model_predictions(preds, suffix=suffix) return preds def merge_proposals(self, preds, threshold, only_leaf2leaf=False): """ - Merges nodes corresponding to for proposals that satify the threshold - and no loop creation requirements. + Merges proposals with model prediction above threshold and does + not create a loop. Parameters ---------- @@ -355,16 +277,25 @@ def merge_proposals(self, preds, threshold, only_leaf2leaf=False): del preds[proposal] return n_accepts - def save_results(self): - """ - Saves the processed results from running the inference pipeline, - namely the corrected SWC files and a list of the merged SWC ids. - """ - self.reconfigure_node_radius() - self.save_graph("corrected_swcs") - self.save_connections() - self.config.save(self.output_dir) - self.log_handle.close() + def remove_proposals(self, preds, threshold): + # Remove based on model predictions and mergeability + cnt = 0 + for proposal, pred in preds.items(): + is_valid = self.dataset.is_mergeable(*proposal) + if pred < threshold or not is_valid: + self.dataset.remove_proposal(proposal) + cnt += 1 + + # Sanity check + for proposal in self.dataset.list_proposals(): + i, j = proposal + if self.dataset.degree[i] > 2 or self.dataset.degree[j] > 2: + self.dataset.remove_proposal(proposal) + cnt += 1 + + self.log("Remove Proposals...") + self.log(f"# Proposals Removed: {cnt}") + self.log(f"# Proposals Remaining: {self.dataset.n_proposals()}\n") # --- Helpers --- def log(self, txt): @@ -396,8 +327,7 @@ def predict(self, data): """ # Generate predictions with torch.inference_mode(): - device = self.config.ml.device - x = data.get_inputs().to(device) + x = data.get_inputs().to(self.device) with torch.cuda.amp.autocast(enabled=True): hat_y = torch.sigmoid(self.model(x)) @@ -406,20 +336,17 @@ def predict(self, data): hat_y = ml_util.tensor_to_list(hat_y) return {idx_to_id[i]: y_i for i, y_i in enumerate(hat_y)} - def save_graph(self, dirname): - # Set paths - temp_dir = os.path.join(self.output_dir, "temp") - output_zip_path = os.path.join(self.output_dir, dirname, "swcs.zip") - util.mkdir(temp_dir) - util.mkdir(os.path.join(self.output_dir, dirname)) - - # Save swcs - self.dataset.to_zipped_swcs_multithreaded(temp_dir) - zip_paths = util.list_paths(temp_dir, extension=".zip") - util.combine_zips(zip_paths, output_zip_path) - util.rmdir(temp_dir) + def save_connections(self): + """ + Writes accepted proposals to a text file. Each line contains the two + SWC IDs as comma separated values. + """ + path = os.path.join(self.output_dir, "connections.txt") + with open(path, "w") as f: + for id1, id2 in self.dataset.merged_ids: + f.write(f"{id1}, {id2}" + "\n") - def save_proposal_results(self, preds_dict, suffix=""): + def save_model_predictions(self, preds_dict, suffix=""): summary = list() for proposal, pred in preds_dict.items(): # Extract info @@ -446,25 +373,3 @@ def save_proposal_results(self, preds_dict, suffix=""): # Save results path = os.path.join(self.output_dir, f"proposal_summary{suffix}.csv") pd.DataFrame(summary).set_index("Proposal").to_csv(path) - - def reconfigure_node_radius(self): - n_nodes = len(self.dataset.node_radius) - self.dataset.node_radius = np.ones((n_nodes), dtype=np.float16) - for i, j in self.dataset.accepts: - self.dataset.node_radius[i] = 6 - self.dataset.node_radius[j] = 6 - - def save_connections(self): - """ - Writes the accepted proposals to a text file. Each line contains the - two SWC IDs as comma separated values. - """ - path = os.path.join(self.output_dir, "connections.txt") - with open(path, "w") as f: - for id1, id2 in self.dataset.merged_ids: - f.write(f"{id1}, {id2}" + "\n") - - def save_fragment_ids(self): - path = f"{self.output_dir}/segment_ids.txt" - segment_ids = list(self.dataset.component_id_to_swc_id.values()) - util.write_list(path, segment_ids) diff --git a/src/neuron_proofreader/utils/graph_util.py b/src/neuron_proofreader/utils/graph_util.py index dad3464..6a4ca7b 100644 --- a/src/neuron_proofreader/utils/graph_util.py +++ b/src/neuron_proofreader/utils/graph_util.py @@ -120,7 +120,9 @@ def __call__(self, swc_pointer): # Continue submitting processes if swc_dicts: - pending.add(executor.submit(self.load, swc_dicts.pop())) + pending.add( + executor.submit(self.load, swc_dicts.pop()) + ) return irreducibles def load(self, swc_dict): diff --git a/src/neuron_proofreader/utils/img_util.py b/src/neuron_proofreader/utils/img_util.py index c0972f0..bbee524 100644 --- a/src/neuron_proofreader/utils/img_util.py +++ b/src/neuron_proofreader/utils/img_util.py @@ -203,15 +203,13 @@ def compute_iou3d(c1, c2, s1, s2): return np.prod(overlap) / union if union > 0 else 0 -def find_img_path(bucket_name, root_dir, brain_id): +def find_img_path(root_prefix, brain_id): """ Finds the path to a whole-brain dataset stored in a GCS bucket. Parameters: ---------- - bucket_name : str - Name of the GCS bucket where the images are stored. - root_dir : str + root_prefrix : str Path to the directory in the GCS bucket where the image is expected to be located. dataset_name : str @@ -222,11 +220,12 @@ def find_img_path(bucket_name, root_dir, brain_id): str Path of the found dataset subdirectory within the specified GCS bucket. """ - for subdir in util.list_gcs_subdirectories(bucket_name, root_dir): - if brain_id in subdir: - img_path = f"gs://{bucket_name}/{subdir}whole-brain/fused.zarr" + bucket_name, _ = util.parse_cloud_path(root_prefix) + for prefix in util.list_gcs_subprefixes(root_prefix): + if brain_id in prefix: + img_path = f"gs://{bucket_name}/{prefix}whole-brain/fused.zarr" return img_path - raise f"Dataset not found in {bucket_name} - {root_dir}" + raise f"Dataset not found in {root_prefix}" def get_contained_voxels(voxels, shape, buffer=0): diff --git a/src/neuron_proofreader/utils/util.py b/src/neuron_proofreader/utils/util.py index 4056cbf..df62329 100644 --- a/src/neuron_proofreader/utils/util.py +++ b/src/neuron_proofreader/utils/util.py @@ -401,7 +401,7 @@ def parse_cloud_path(path): """ # Remove s3:// or gs:// if present if path.startswith("s3://") or path.startswith("gs://"): - path = path[len("s3://"):] + path = path[len("s3://") :] # Split path parts = path.split("/", 1)