From 3719e4484811b43df1f91143c0c7fdeed01f7205 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Wed, 27 May 2026 19:24:45 +0000 Subject: [PATCH 1/2] refactor: improved swc load, graph load, augmentation --- .../machine_learning/augmentation.py | 26 +++---- src/neuron_proofreader/utils/graph_util.py | 17 +---- src/neuron_proofreader/utils/swc_util.py | 67 ++++++++++--------- 3 files changed, 49 insertions(+), 61 deletions(-) diff --git a/src/neuron_proofreader/machine_learning/augmentation.py b/src/neuron_proofreader/machine_learning/augmentation.py index 8d7b172..2fd8555 100644 --- a/src/neuron_proofreader/machine_learning/augmentation.py +++ b/src/neuron_proofreader/machine_learning/augmentation.py @@ -62,8 +62,7 @@ def __init__(self, axes=(0, 1, 2)): Parameters ---------- axes : Tuple[float], optional - Tuple of integers representing the axes along which to flip the - image. Default is (0, 1, 2). + Axes along which to flip the image. Default is (0, 1, 2). """ self.axes = axes @@ -79,8 +78,8 @@ def __call__(self, patches): """ for axis in self.axes: if random.random() > 0.5: - patches[0, ...] = np.flip(patches[0, ...], axis=axis) - patches[1, ...] = np.flip(patches[1, ...], axis=axis) + patches[0] = np.flip(patches[0], axis=axis) + patches[1] = np.flip(patches[1], axis=axis) return patches @@ -116,8 +115,8 @@ def __call__(self, patches): for axes in self.axes: if random.random() < 0.5: angle = random.uniform(*self.angles) - self.rotate3d(patches[0, ...], angle, axes, False) - self.rotate3d(patches[1, ...], angle, axes, True) + patches[0] = self.rotate3d(patches[0], angle, axes, False) + patches[1] = self.rotate3d(patches[1], angle, axes, True) return patches @staticmethod @@ -149,6 +148,7 @@ def rotate3d(img_patch, angle, axes, is_segmentation=False): order=order, ) img_patch /= multipler + return img_patch class RandomScale3D: @@ -197,8 +197,8 @@ def __call__(self, patches): ] # Rescale images - patches[0, ...] = zoom(patches[0, ...], zoom_factors, order=3) - patches[1, ...] = zoom(patches[1, ...], zoom_factors, order=0) + patches[0] = zoom(patches[0], zoom_factors, order=3) + patches[1] = zoom(patches[1], zoom_factors, order=0) return patches @@ -208,7 +208,7 @@ class RandomContrast3D: Adjusts the contrast of a 3D image by scaling voxel intensities. """ - def __init__(self, p_low=(0, 90), p_high=(97.5, 100)): + def __init__(self, p_low=(0, 80), p_high=(98, 100)): """ Initializes a RandomContrast3D transformer. @@ -253,7 +253,7 @@ def __init__(self, max_std=0.2): """ self.max_std = max_std - def __call__(self, img_patches): + def __call__(self, patches): """ Adds Gaussian noise to the input 3D image. @@ -264,6 +264,6 @@ def __call__(self, img_patches): the input image and "patches[1, ...]" is from the segmentation. """ std = self.max_std * random.random() - img_patches[0] += np.random.uniform(-std, std, img_patches[0].shape) - img_patches[0] = np.clip(img_patches[0], 0, 1) - return img_patches + patches[0] += np.random.uniform(-std, std, patches[0].shape) + patches[0] = np.clip(patches[0], 0, 1) + return patches diff --git a/src/neuron_proofreader/utils/graph_util.py b/src/neuron_proofreader/utils/graph_util.py index fb77eda..2fbdf4d 100644 --- a/src/neuron_proofreader/utils/graph_util.py +++ b/src/neuron_proofreader/utils/graph_util.py @@ -7,19 +7,6 @@ Code that loads and preprocesses neuron fragments stored as SWC files, then constructs a custom graph object called a "FragmentsGraph". - Graph Loading Algorithm: - 1. Load Soma Locations (Optional) - - 2. Extract Irreducibles from SWC files - a. Build graph from SWC file - b. Break soma merges (optional) - c. Break high risk merges (optional) - d. Find irreducible nodes - e. Find irreducible edges - - -Note: We use the term "branch" to refer to a path in a graph from a branching - node to a leaf. """ from collections import deque @@ -80,9 +67,7 @@ def __init__( self.node_spacing = node_spacing self.prefetch = prefetch self.prune_depth = prune_depth - self.swc_reader = swc_util.Reader( - anisotropy, min_cable_length, verbose - ) + self.swc_reader = swc_util.Reader(anisotropy, verbose) self.verbose = verbose def __call__(self, swc_pointer): diff --git a/src/neuron_proofreader/utils/swc_util.py b/src/neuron_proofreader/utils/swc_util.py index 903c784..d70a3ad 100644 --- a/src/neuron_proofreader/utils/swc_util.py +++ b/src/neuron_proofreader/utils/swc_util.py @@ -1,5 +1,5 @@ """ -Created on Thu May 21 12:00:00 2026 +Created on Wed June 5 16:00:00 2023 @author: Anna Grim @email: anna.grim@alleninstitute.org @@ -14,17 +14,17 @@ "z" (float): z coordinate "pid" (int): node ID of parent -Note: Each line in an SWC file corresponds to a node and contains these - attributes in the same order. +Note: Each uncommented line in an SWC file corresponds to a node and contains + these attributes in the same order. """ from botocore import UNSIGNED -from botocore.client import Config +from botocore.config import Config from collections import deque from concurrent.futures import ( - as_completed, ProcessPoolExecutor, ThreadPoolExecutor, + as_completed, ) from google.auth.exceptions import RefreshError, TransportError from google.cloud import storage @@ -92,6 +92,10 @@ def __call__(self, swc_pointer): - "filename": filename of SWC file - "swc_id": name of SWC file, minus the ".swc". """ + # List of paths + if isinstance(swc_pointer, list): + return self.read_swcs(swc_pointer) + # Directory containing... if os.path.isdir(swc_pointer): # Local ZIP archives with SWC files @@ -102,7 +106,7 @@ def __call__(self, swc_pointer): # Local SWC files paths = util.read_paths(swc_pointer, extension=".swc") if len(paths) > 0: - return self.read_swcs(paths, self.read_swc) + return self.read_swcs(paths) raise Exception("Directory is Invalid!") @@ -143,14 +147,12 @@ def read_swc(self, path): filename = os.path.basename(path) return self.parse(content, filename) - def read_swcs(self, swc_paths, read_fn): + def read_swcs(self, swc_paths): """ Reads SWC files stored in a GCS or S3 bucket. Parameters ---------- - bucket_name : str - Name of bucket containing SWC files. swc_paths : List[str] List of paths to SWC files to be read. @@ -164,7 +166,7 @@ def read_swcs(self, swc_paths, read_fn): # Assign threads threads = set() for path in swc_paths: - threads.add(executor.submit(read_fn, path)) + threads.add(executor.submit(self.read_swc, path)) # Store results swc_dicts = deque() @@ -194,7 +196,7 @@ def read_zips(self, zip_paths, read_fn): Dictionaries whose keys and values are the attribute names and values from an SWC file. """ - pbar = self.manual_progress_bar(len(zip_paths)) + pbar = tqdm(total=len(zip_paths), desc="Read SWCs") with ProcessPoolExecutor() as executor: # Assign processes futures = {executor.submit(read_fn, path) for path in zip_paths} @@ -288,7 +290,7 @@ def read_from_cloud(self, path): # Call reader if swc_paths: - return self.read_swcs(swc_paths, self.read_swc) + return self.read_swcs(swc_paths) elif zip_paths: read_fn = self.read_s3_zip if use_s3 else self.read_gcs_zip return self.read_zips(zip_paths, read_fn) @@ -336,26 +338,22 @@ def read_gcs_zip(self, path): Dictionaries whose keys and values are the attribute names and values from an SWC file. """ - # Download ZIP + # Initialize cloud reader + client = storage.Client() bucket_name, path = util.parse_cloud_path(path) - bucket = storage.Client().bucket(bucket_name) - try: - zip_content = bucket.blob(path).download_as_bytes() - except TransportError: - print(f"Failed to read {path}!") - return deque() + bucket = client.bucket(bucket_name) - # Parse ZIP contents + # Parse Zip swc_dicts = deque() + zip_content = bucket.blob(path).download_as_bytes() with ZipFile(BytesIO(zip_content), "r") as zf: with ThreadPoolExecutor() as executor: # Assign threads threads = set() for name in zf.namelist(): - if self.confirm_read(name): - threads.add( - executor.submit(self.read_zipped_swc, zf, name) - ) + threads.add( + executor.submit(self.read_zipped_swc, zf, name) + ) # Process results for thread in as_completed(threads): @@ -460,6 +458,7 @@ def parse(self, content, filename): "pid": np.zeros((len(content)), dtype=int), "radius": np.zeros((len(content)), dtype=float), "xyz": np.zeros((len(content), 3), dtype=np.int32), + "soma_nodes": set(), "swc_name": swc_name, } @@ -471,6 +470,9 @@ def parse(self, content, filename): swc_dict["radius"][i] = float(parts[-2]) swc_dict["xyz"][i] = self.read_coordinate(parts[2:5], offset) + if int(parts[1]) == 1: + swc_dict["soma_nodes"].add(parts[0]) + # Convert radius from nanometers to microns if swc_dict["radius"][0] > 100: swc_dict["radius"] /= 1000 @@ -547,10 +549,10 @@ def write_points( radius : float, optional Radius to be used in SWC file. Default is 10. """ - zf = ZipFile(zip_path, write_mode) + zip_writer = ZipFile(zip_path, write_mode) for i, xyz in enumerate(points): filename = prefix + str(i + 1) + ".swc" - to_zipped_point(zf, filename, xyz, color=color, radius=radius) + to_zipped_point(zip_writer, filename, xyz, color=color, radius=radius) def to_zipped_point(zip_writer, filename, xyz, color=None, radius=5): @@ -628,12 +630,15 @@ def get_swc_name(path): def to_graph(swc_dict): """ - Converts an SWC dictionary to a NetworkX graph with reindexed nodes. + Converts an SWC dict to a NetworkX graph with reindexed nodes. Parameters ---------- swc_dict : dict Contents of an SWC file. + set_attrs : bool, optional + Indication of whether to set "xyz" and "radius" as graph-level + attributes. Default is False. Returns ------- @@ -648,10 +653,8 @@ def to_graph(swc_dict): ] # Build graph with reindexed edges - graph = nx.Graph( - swc_name=swc_dict["swc_name"], - radius=swc_dict["radius"], - xyz=swc_dict["xyz"], - ) + graph = nx.Graph(swc_name=swc_dict["swc_name"]) graph.add_edges_from(edges) + graph.graph["xyz"] = swc_dict["xyz"] + graph.graph["radius"] = swc_dict["radius"] return graph From 0ab87bb68239364da186f3bacbc41a4f2bd4db85 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Wed, 27 May 2026 19:34:32 +0000 Subject: [PATCH 2/2] merge from v2 branch --- src/neuron_proofreader/utils/swc_util.py | 42 ++++++++++++++---------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/src/neuron_proofreader/utils/swc_util.py b/src/neuron_proofreader/utils/swc_util.py index d70a3ad..6300fd5 100644 --- a/src/neuron_proofreader/utils/swc_util.py +++ b/src/neuron_proofreader/utils/swc_util.py @@ -196,7 +196,7 @@ def read_zips(self, zip_paths, read_fn): Dictionaries whose keys and values are the attribute names and values from an SWC file. """ - pbar = tqdm(total=len(zip_paths), desc="Read SWCs") + pbar = self.manual_progress_bar(len(zip_paths)) with ProcessPoolExecutor() as executor: # Assign processes futures = {executor.submit(read_fn, path) for path in zip_paths} @@ -338,12 +338,16 @@ def read_gcs_zip(self, path): Dictionaries whose keys and values are the attribute names and values from an SWC file. """ - # Initialize cloud reader - client = storage.Client() + # Download ZIP bucket_name, path = util.parse_cloud_path(path) - bucket = client.bucket(bucket_name) + bucket = storage.Client().bucket(bucket_name) + try: + zip_content = bucket.blob(path).download_as_bytes() + except TransportError: + print(f"Failed to read {path}!") + return deque() - # Parse Zip + # Parse ZIP swc_dicts = deque() zip_content = bucket.blob(path).download_as_bytes() with ZipFile(BytesIO(zip_content), "r") as zf: @@ -549,21 +553,21 @@ def write_points( radius : float, optional Radius to be used in SWC file. Default is 10. """ - zip_writer = ZipFile(zip_path, write_mode) + zf = ZipFile(zip_path, write_mode) for i, xyz in enumerate(points): filename = prefix + str(i + 1) + ".swc" - to_zipped_point(zip_writer, filename, xyz, color=color, radius=radius) + to_zipped_point(zf, filename, xyz, color=color, radius=radius) -def to_zipped_point(zip_writer, filename, xyz, color=None, radius=5): +def to_zipped_point(zf, filename, xyz, color=None, radius=5): """ Writes a point to an SWC file format, which is then stored in a ZIP archive. Parameters ---------- - zip_writer : zipfile.ZipFile - ZipFile object that will store the generated SWC file. + zf : zipfile.ZipFile + ZipFile used to write the generated SWC file. filename : str Filename of SWC file. xyz : ArrayLike @@ -571,7 +575,7 @@ def to_zipped_point(zip_writer, filename, xyz, color=None, radius=5): color : str, optional Color of nodes. Default is None. radius : float, optional - Radius of point. Default is 5um. + Radius (in microns) of point. Default is 5. """ with StringIO() as text_buffer: # Preamble @@ -584,7 +588,7 @@ def to_zipped_point(zip_writer, filename, xyz, color=None, radius=5): text_buffer.write("\n" + f"1 5 {x} {y} {z} {radius} -1") # Finish - zip_writer.writestr(filename, text_buffer.getvalue()) + zf.writestr(filename, text_buffer.getvalue()) # --- Helpers --- @@ -595,7 +599,7 @@ def get_segment_id(swc_name): Parameters ---------- swc_name : str - SWC filename, expected to be in the format "{segment_id}.swc". + SWC filename in the format "{segment_id}.swc". Returns ------- @@ -611,7 +615,7 @@ def get_segment_id(swc_name): def get_swc_name(path): """ - Gets name of the SWC file loacted at the given path, minus the extension. + Gets name of the SWC file at the given path, minus the extension. Parameters ---------- @@ -630,7 +634,7 @@ def get_swc_name(path): def to_graph(swc_dict): """ - Converts an SWC dict to a NetworkX graph with reindexed nodes. + Converts an SWC dictionary to a NetworkX graph with reindexed nodes. Parameters ---------- @@ -653,8 +657,10 @@ def to_graph(swc_dict): ] # Build graph with reindexed edges - graph = nx.Graph(swc_name=swc_dict["swc_name"]) + graph = nx.Graph( + swc_name=swc_dict["swc_name"], + radius=swc_dict["radius"], + xyz=swc_dict["xyz"], + ) graph.add_edges_from(edges) - graph.graph["xyz"] = swc_dict["xyz"] - graph.graph["radius"] = swc_dict["radius"] return graph