Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions src/neuron_proofreader/machine_learning/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ def __init__(self, axes=(0, 1, 2)):
Parameters
----------
axes : Tuple[int], optional
Tuple of integers representing the axes along which to flip the
image. Default is (0, 1, 2).
Axes along which to flip the image. Default is (0, 1, 2).
"""
self.axes = axes

Expand All @@ -79,8 +78,8 @@ def __call__(self, patches):
"""
for axis in self.axes:
if random.random() > 0.5:
patches[0, ...] = np.flip(patches[0, ...], axis=axis)
patches[1, ...] = np.flip(patches[1, ...], axis=axis)
patches[0] = np.flip(patches[0], axis=axis)
patches[1] = np.flip(patches[1], axis=axis)
return patches


Expand Down Expand Up @@ -116,8 +115,8 @@ def __call__(self, patches):
for axes in self.axes:
if random.random() < 0.5:
angle = random.uniform(*self.angles)
self.rotate3d(patches[0, ...], angle, axes, False)
self.rotate3d(patches[1, ...], angle, axes, True)
patches[0] = self.rotate3d(patches[0], angle, axes, False)
patches[1] = self.rotate3d(patches[1], angle, axes, True)
return patches

@staticmethod
Expand Down Expand Up @@ -149,6 +148,7 @@ def rotate3d(img_patch, angle, axes, is_segmentation=False):
order=order,
)
img_patch /= multipler
return img_patch


class RandomScale3D:
Expand Down Expand Up @@ -197,8 +197,8 @@ def __call__(self, patches):
]

# Rescale images
patches[0, ...] = zoom(patches[0, ...], zoom_factors, order=3)
patches[1, ...] = zoom(patches[1, ...], zoom_factors, order=0)
patches[0] = zoom(patches[0], zoom_factors, order=3)
patches[1] = zoom(patches[1], zoom_factors, order=0)
return patches


Expand All @@ -208,7 +208,7 @@ class RandomContrast3D:
Adjusts the contrast of a 3D image by scaling voxel intensities.
"""

def __init__(self, p_low=(0, 90), p_high=(97.5, 100)):
def __init__(self, p_low=(0, 80), p_high=(98, 100)):
"""
Initializes a RandomContrast3D transformer.

Expand Down Expand Up @@ -253,7 +253,7 @@ def __init__(self, max_std=0.2):
"""
self.max_std = max_std

def __call__(self, img_patches):
def __call__(self, patches):
"""
Adds Gaussian noise to the input 3D image.

Expand All @@ -264,6 +264,6 @@ def __call__(self, img_patches):
the input image and "patches[1, ...]" is from the segmentation.
"""
std = self.max_std * random.random()
img_patches[0] += np.random.uniform(-std, std, img_patches[0].shape)
img_patches[0] = np.clip(img_patches[0], 0, 1)
return img_patches
patches[0] += np.random.uniform(-std, std, patches[0].shape)
patches[0] = np.clip(patches[0], 0, 1)
return patches
17 changes: 1 addition & 16 deletions src/neuron_proofreader/utils/graph_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,6 @@
Code that loads and preprocesses neuron fragments stored as SWC files, then
constructs a custom graph object called a "FragmentsGraph".

Graph Loading Algorithm:
1. Load Soma Locations (Optional)

2. Extract Irreducibles from SWC files
a. Build graph from SWC file
b. Break soma merges (optional)
c. Break high risk merges (optional)
d. Find irreducible nodes
e. Find irreducible edges


Note: We use the term "branch" to refer to a path in a graph from a branching
node to a leaf.
"""

from collections import deque
Expand Down Expand Up @@ -80,9 +67,7 @@ def __init__(
self.node_spacing = node_spacing
self.prefetch = prefetch
self.prune_depth = prune_depth
self.swc_reader = swc_util.Reader(
anisotropy, min_cable_length, verbose
)
self.swc_reader = swc_util.Reader(anisotropy, verbose)
self.verbose = verbose

def __call__(self, swc_pointer):
Expand Down
55 changes: 32 additions & 23 deletions src/neuron_proofreader/utils/swc_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Created on Thu May 21 12:00:00 2026
Created on Wed June 5 16:00:00 2023

@author: Anna Grim
@email: anna.grim@alleninstitute.org
Expand All @@ -14,17 +14,17 @@
"z" (float): z coordinate
"pid" (int): node ID of parent

Note: Each line in an SWC file corresponds to a node and contains these
attributes in the same order.
Note: Each uncommented line in an SWC file corresponds to a node and contains
these attributes in the same order.
"""

from botocore import UNSIGNED
from botocore.client import Config
from botocore.config import Config
from collections import deque
from concurrent.futures import (
as_completed,
ProcessPoolExecutor,
ThreadPoolExecutor,
as_completed,
)
from google.auth.exceptions import RefreshError, TransportError
from google.cloud import storage
Expand Down Expand Up @@ -92,6 +92,10 @@ def __call__(self, swc_pointer):
- "filename": filename of SWC file
- "swc_id": name of SWC file, minus the ".swc".
"""
# List of paths
if isinstance(swc_pointer, list):
return self.read_swcs(swc_pointer)

# Directory containing...
if os.path.isdir(swc_pointer):
# Local ZIP archives with SWC files
Expand All @@ -102,7 +106,7 @@ def __call__(self, swc_pointer):
# Local SWC files
paths = util.read_paths(swc_pointer, extension=".swc")
if len(paths) > 0:
return self.read_swcs(paths, self.read_swc)
return self.read_swcs(paths)

raise Exception("Directory is Invalid!")

Expand Down Expand Up @@ -143,14 +147,12 @@ def read_swc(self, path):
filename = os.path.basename(path)
return self.parse(content, filename)

def read_swcs(self, swc_paths, read_fn):
def read_swcs(self, swc_paths):
"""
Reads SWC files at the given paths, which may be local or stored in a GCS or S3 bucket.

Parameters
----------
bucket_name : str
Name of bucket containing SWC files.
swc_paths : List[str]
List of paths to SWC files to be read.

Expand All @@ -164,7 +166,7 @@ def read_swcs(self, swc_paths, read_fn):
# Assign threads
threads = set()
for path in swc_paths:
threads.add(executor.submit(read_fn, path))
threads.add(executor.submit(self.read_swc, path))

# Store results
swc_dicts = deque()
Expand Down Expand Up @@ -288,7 +290,7 @@ def read_from_cloud(self, path):

# Call reader
if swc_paths:
return self.read_swcs(swc_paths, self.read_swc)
return self.read_swcs(swc_paths)
elif zip_paths:
read_fn = self.read_s3_zip if use_s3 else self.read_gcs_zip
return self.read_zips(zip_paths, read_fn)
Expand Down Expand Up @@ -345,17 +347,17 @@ def read_gcs_zip(self, path):
print(f"Failed to read {path}!")
return deque()

# Parse ZIP contents
# Parse ZIP
swc_dicts = deque()
zip_content = bucket.blob(path).download_as_bytes()
with ZipFile(BytesIO(zip_content), "r") as zf:
with ThreadPoolExecutor() as executor:
# Assign threads
threads = set()
for name in zf.namelist():
if self.confirm_read(name):
threads.add(
executor.submit(self.read_zipped_swc, zf, name)
)
threads.add(
executor.submit(self.read_zipped_swc, zf, name)
)

# Process results
for thread in as_completed(threads):
Expand Down Expand Up @@ -460,6 +462,7 @@ def parse(self, content, filename):
"pid": np.zeros((len(content)), dtype=int),
"radius": np.zeros((len(content)), dtype=float),
"xyz": np.zeros((len(content), 3), dtype=np.int32),
"soma_nodes": set(),
"swc_name": swc_name,
}

Expand All @@ -471,6 +474,9 @@ def parse(self, content, filename):
swc_dict["radius"][i] = float(parts[-2])
swc_dict["xyz"][i] = self.read_coordinate(parts[2:5], offset)

if int(parts[1]) == 1:
swc_dict["soma_nodes"].add(parts[0])

# Convert radius from nanometers to microns
if swc_dict["radius"][0] > 100:
swc_dict["radius"] /= 1000
Expand Down Expand Up @@ -553,23 +559,23 @@ def write_points(
to_zipped_point(zf, filename, xyz, color=color, radius=radius)


def to_zipped_point(zip_writer, filename, xyz, color=None, radius=5):
def to_zipped_point(zf, filename, xyz, color=None, radius=5):
"""
Writes a point to an SWC file format, which is then stored in a ZIP
archive.

Parameters
----------
zip_writer : zipfile.ZipFile
ZipFile object that will store the generated SWC file.
zf : zipfile.ZipFile
ZipFile used to write the generated SWC file.
filename : str
Filename of SWC file.
xyz : ArrayLike
Point to be written to SWC file.
color : str, optional
Color of nodes. Default is None.
radius : float, optional
Radius of point. Default is 5um.
Radius (in microns) of point. Default is 5.
"""
with StringIO() as text_buffer:
# Preamble
Expand All @@ -582,7 +588,7 @@ def to_zipped_point(zip_writer, filename, xyz, color=None, radius=5):
text_buffer.write("\n" + f"1 5 {x} {y} {z} {radius} -1")

# Finish
zip_writer.writestr(filename, text_buffer.getvalue())
zf.writestr(filename, text_buffer.getvalue())


# --- Helpers ---
Expand All @@ -593,7 +599,7 @@ def get_segment_id(swc_name):
Parameters
----------
swc_name : str
SWC filename, expected to be in the format "{segment_id}.swc".
SWC filename in the format "{segment_id}.swc".

Returns
-------
Expand All @@ -609,7 +615,7 @@ def get_segment_id(swc_name):

def get_swc_name(path):
"""
Gets name of the SWC file loacted at the given path, minus the extension.
Gets name of the SWC file at the given path, minus the extension.

Parameters
----------
Expand All @@ -634,6 +640,9 @@ def to_graph(swc_dict):
----------
swc_dict : dict
Contents of an SWC file.
set_attrs : bool, optional
Indication of whether to set "xyz" and "radius" as graph-level
attributes. Default is False.

Returns
-------
Expand Down
Loading