diff --git a/doc_template/source/conf.py b/doc_template/source/conf.py index b38833ac..6e55fb07 100644 --- a/doc_template/source/conf.py +++ b/doc_template/source/conf.py @@ -1,4 +1,5 @@ """Configuration file for the Sphinx documentation builder.""" + # # For the full list of built-in configuration values, see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html diff --git a/src/neuron_proofreader/__init__.py b/src/neuron_proofreader/__init__.py index d0a85479..024d51f2 100644 --- a/src/neuron_proofreader/__init__.py +++ b/src/neuron_proofreader/__init__.py @@ -1,2 +1,3 @@ """Init package""" + __version__ = "0.0.0" diff --git a/src/neuron_proofreader/machine_learning/augmentation.py b/src/neuron_proofreader/machine_learning/augmentation.py index 91951eba..8d7b1722 100644 --- a/src/neuron_proofreader/machine_learning/augmentation.py +++ b/src/neuron_proofreader/machine_learning/augmentation.py @@ -30,7 +30,7 @@ def __init__(self): RandomFlip3D(), RandomRotation3D(), RandomNoise3D(), - RandomContrast3D() + RandomContrast3D(), ] def __call__(self, patches): diff --git a/src/neuron_proofreader/machine_learning/exaspim_dataloader.py b/src/neuron_proofreader/machine_learning/exaspim_dataloader.py index 997a7cca..7724dc86 100644 --- a/src/neuron_proofreader/machine_learning/exaspim_dataloader.py +++ b/src/neuron_proofreader/machine_learning/exaspim_dataloader.py @@ -328,9 +328,7 @@ def sample_bright_voxel(self, brain_id): pending = dict() for _ in range(self.prefetch_foreground_sampling): voxel = self.sample_interior_voxel(brain_id) - thread = executor.submit( - self.read_image, brain_id, voxel - ) + thread = executor.submit(self.read_image, brain_id, voxel) pending[thread] = voxel # Check if image patch is bright enough @@ -489,8 +487,20 @@ def _load_batch(self, start_idx): ) # Process results - img_patches = np.zeros((batch_size, 1,) + self.patch_shape) - mask_patches = np.zeros((batch_size, 1,) + self.patch_shape) + img_patches = np.zeros( + ( + batch_size, + 1, + ) + + self.patch_shape + ) + mask_patches = np.zeros( + ( + batch_size, + 1, + ) + + self.patch_shape + ) for i, process in enumerate(as_completed(processes)): img, mask = process.result() img_patches[i, 0, ...] = img diff --git a/src/neuron_proofreader/machine_learning/geometric_gnn_models.py b/src/neuron_proofreader/machine_learning/geometric_gnn_models.py index 4f65c90d..0c68d7f1 100644 --- a/src/neuron_proofreader/machine_learning/geometric_gnn_models.py +++ b/src/neuron_proofreader/machine_learning/geometric_gnn_models.py @@ -75,9 +75,7 @@ def __init__(self, ggnn_name, output_dim=64): # Set geometric gnn if ggnn_name == "egnn": self.geometric_gnn = EGNN( - in_node_dim=1, - hidden_dim=32, - out_node_dim=output_dim + in_node_dim=1, hidden_dim=32, out_node_dim=output_dim ) # --- Core Routines --- @@ -97,7 +95,9 @@ def forward(self, h, x, edge_index, batch): ) # Pool node embeddings - h_g, x_g, edge_index_g = self.pool_nonbranching_paths(h_g, x_g, edge_index_g) + h_g, x_g, edge_index_g = self.pool_nonbranching_paths( + h_g, x_g, edge_index_g + ) # Encode pooled graph h_g = self.encode_pooled_graph(h_g, x_g, edge_index_g) @@ -158,7 +158,9 @@ def pool_nonbranching_paths(self, h, x, edge_index): # Finish h_pooled = torch.stack(h_pooled, dim=0) x_pooled = torch.stack(x_pooled, dim=0) - edge_index_pooled = self.get_edge_index_pooled(edge_index, node_to_path) + edge_index_pooled = self.get_edge_index_pooled( + edge_index, node_to_path + ) return h_pooled, x_pooled, edge_index_pooled def get_adj_and_deg(self, edge_index, num_nodes): @@ -201,10 +203,13 @@ def extract_subgraph(self, h, x, edge_index, node_mask): id_map = {int(n): i for i, n in enumerate(node_ids.tolist())} edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]] edge_index_g = edge_index[:, edge_mask] - edge_index_g = torch.stack([ - torch.tensor([id_map[int(u)] for u in edge_index_g[0]]), - torch.tensor([id_map[int(v)] for v in edge_index_g[1]]) - ], dim=0) + edge_index_g = torch.stack( + [ + torch.tensor([id_map[int(u)] for u in edge_index_g[0]]), + torch.tensor([id_map[int(v)] for v in edge_index_g[1]]), + ], + dim=0, + ) return h_g, x_g, edge_index_g @staticmethod diff --git a/src/neuron_proofreader/machine_learning/gnn_models.py b/src/neuron_proofreader/machine_learning/gnn_models.py index c8954ec3..b5add009 100644 --- a/src/neuron_proofreader/machine_learning/gnn_models.py +++ b/src/neuron_proofreader/machine_learning/gnn_models.py @@ -26,6 +26,7 @@ class VisionHGAT(torch.nn.Module): Heterogeneous graph attention network that processes multimodal features such as image patches and feature vectors. """ + # Class attributes relations = [ str(("branch", "to", "branch")), @@ -47,7 +48,9 @@ def __init__( # Initial embeddings self.node_embedding = init_node_embedding(hidden_dim) - self.patch_embedding = init_patch_embedding(patch_shape, hidden_dim // 2) + self.patch_embedding = init_patch_embedding( + patch_shape, hidden_dim // 2 + ) # Message passing layers self.disable_msg_passing = disable_msg_passing @@ -58,7 +61,7 @@ def __init__( else: self.gat1 = self.init_gat(hidden_dim, hidden_dim, heads) self.gat2 = self.init_gat(hidden_dim * heads, hidden_dim, heads) - self.output = nn.Linear(hidden_dim * heads ** 2, 1) + self.output = nn.Linear(hidden_dim * heads**2, 1) # Initialize weights self.init_weights() @@ -81,9 +84,7 @@ def init_mlp_layers(self, hidden_dim, n_layers=2): for _ in range(n_layers): layers.append( nn_geometric.HeteroDictLinear( - hidden_dim, - hidden_dim, - types=("branch", "proposal") + hidden_dim, hidden_dim, types=("branch", "proposal") ) ) return layers @@ -160,10 +161,12 @@ def init_node_embedding(output_dim): dim_p = node_input_dims["proposal"] # Set node embedding layer - node_embedding = nn.ModuleDict({ - "branch": FeedForwardNet(dim_b, output_dim, 3), - "proposal": FeedForwardNet(dim_p, output_dim // 2, 3), - }) + node_embedding = nn.ModuleDict( + { + "branch": FeedForwardNet(dim_b, output_dim, 3), + "proposal": FeedForwardNet(dim_p, output_dim // 2, 3), + } + ) return node_embedding diff --git a/src/neuron_proofreader/machine_learning/train.py b/src/neuron_proofreader/machine_learning/train.py index bce9b6e3..20edafb4 100644 --- a/src/neuron_proofreader/machine_learning/train.py +++ b/src/neuron_proofreader/machine_learning/train.py @@ -68,7 +68,7 @@ def __init__( lr=1e-3, max_epochs=200, min_recall=0, - save_mistake_mips=False + save_mistake_mips=False, ): """ Instantiates a Trainer object. @@ -292,7 +292,7 @@ def compute_stats(y, hat_y): "f1": avg_f1, "precision": avg_prec, "recall": avg_recall, - "accuracy": avg_acc + "accuracy": avg_acc, } return stats @@ -426,7 +426,7 @@ def __init__( device="cuda", lr=1e-3, max_epochs=200, - save_mistake_mips=False + save_mistake_mips=False, ): """ Instantiates a DistributedTrainer object. @@ -452,7 +452,7 @@ def __init__( device=device, lr=lr, max_epochs=max_epochs, - save_mistake_mips=save_mistake_mips + save_mistake_mips=save_mistake_mips, ) # Check that multiple GPUs are available diff --git a/src/neuron_proofreader/machine_learning/vision_models.py b/src/neuron_proofreader/machine_learning/vision_models.py index bf99fe65..21b6760d 100644 --- a/src/neuron_proofreader/machine_learning/vision_models.py +++ b/src/neuron_proofreader/machine_learning/vision_models.py @@ -142,7 +142,7 @@ def __init__(self, checkpoint_path, model_config): checkpoint_path=checkpoint_path, model_config=model_config, task_head_config="binary_classifier", - freeze_encoder=True + freeze_encoder=True, ) # Instance attributes diff --git a/src/neuron_proofreader/merge_proofreading/merge_dataloading.py b/src/neuron_proofreader/merge_proofreading/merge_dataloading.py index 0bf8656d..d9c6fd50 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_dataloading.py +++ b/src/neuron_proofreader/merge_proofreading/merge_dataloading.py @@ -43,7 +43,9 @@ def load_fragments(dataset, is_test=False): sub_df = merge_sites_df.loc[merge_sites_df["brain_id"] == brain_id] for segmentation_id in sub_df["segmentation_id"].unique(): if (brain_id, segmentation_id) in target_pairs: - swc_pointer = f"{root}/{brain_id}/{segmentation_id}/merged_fragments.zip" + swc_pointer = ( + f"{root}/{brain_id}/{segmentation_id}/merged_fragments.zip" + ) dataset.load_fragment_graphs( brain_id, swc_pointer, use_anisotropy=False ) diff --git a/src/neuron_proofreader/merge_proofreading/merge_datasets.py b/src/neuron_proofreader/merge_proofreading/merge_datasets.py index fe9f4420..63a838e0 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datasets.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datasets.py @@ -29,7 +29,7 @@ subgraph_to_point_cloud, ) from neuron_proofreader.merge_proofreading.merge_dataloading import ( - get_brain_merge_sites + get_brain_merge_sites, ) from neuron_proofreader.skeleton_graph import SkeletonGraph from neuron_proofreader.utils import ( @@ -71,6 +71,7 @@ class MergeSiteDataset(Dataset): patch_shape : Tuple[int], optional Shape of the 3D image patches to extract. """ + random_negative_example_prob = 0.8 def __init__( @@ -121,9 +122,7 @@ def __init__( self.merge_site_kdtrees = dict() # --- Load Data --- - def load_fragment_graphs( - self, brain_id, swc_pointer, use_anisotropy=True - ): + def load_fragment_graphs(self, brain_id, swc_pointer, use_anisotropy=True): """ Loads fragments containing merge mistakes for a whole-brain dataset, then stores them in the "graphs" attribute. @@ -139,7 +138,7 @@ def load_fragment_graphs( graph = SkeletonGraph( anisotropy=self.anisotropy, node_spacing=self.node_spacing, - use_anisotropy=use_anisotropy + use_anisotropy=use_anisotropy, ) graph.load(swc_pointer) @@ -771,6 +770,7 @@ def generate_negative_examples(self): negative_examples : List[dict] List of negative examples collected across all graphs. """ + # Subroutines def add_examples(): """ @@ -924,7 +924,7 @@ def __init__( is_multimodal=False, modality=None, sampler=None, - use_shuffle=True + use_shuffle=True, ): """ Instantiates a MergeSiteDataLoader object. @@ -970,11 +970,11 @@ def __iter__(self): for start in range(0, len(idxs), self.batch_size): end = min(start + self.batch_size, len(idxs)) if self.is_multimodal and self.modality == "graph": - yield self._load_image_graph_batch(idxs[start: end]) + yield self._load_image_graph_batch(idxs[start:end]) elif self.is_multimodal and self.modality == "pointcloud": - yield self._load_image_pc_batch(idxs[start: end]) + yield self._load_image_pc_batch(idxs[start:end]) else: - yield self._load_image_batch(idxs[start: end]) + yield self._load_image_batch(idxs[start:end]) def _load_image_batch(self, batch_idxs): """ @@ -1084,9 +1084,7 @@ def _load_image_graph_batch(self, idxs): h.append(h_i) x.append(x_i) edge_index.append(edge_index_i) - batches.append( - torch.full((n_i,), i, dtype=torch.long) - ) + batches.append(torch.full((n_i,), i, dtype=torch.long)) node_offset += n_i @@ -1100,7 +1098,7 @@ def _load_image_graph_batch(self, idxs): batch = ml_util.TensorDict( { "img": ml_util.to_tensor(patches), - "graph": (h, x, edge_index, batches) + "graph": (h, x, edge_index, batches), } ) return batch, ml_util.to_tensor(targets) diff --git a/src/neuron_proofreader/merge_proofreading/merge_inference.py b/src/neuron_proofreader/merge_proofreading/merge_inference.py index 73bb319d..474ea3a9 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_inference.py +++ b/src/neuron_proofreader/merge_proofreading/merge_inference.py @@ -250,7 +250,7 @@ def __init__( prefetch=64, segmentation_path=None, subgraph_radius=100, - use_new_mask=False + use_new_mask=False, ): # Call parent class super().__init__() @@ -317,7 +317,9 @@ def find_fragments_to_search(self): for nodes in nx.connected_components(self.graph): # Compute path length node = util.sample_once(list(nodes)) - length = self.graph.cable_length(max_depth=self.min_size, root=node) + length = self.graph.cable_length( + max_depth=self.min_size, root=node + ) # Check if path length satisfies threshold if length > self.min_size: @@ -431,7 +433,7 @@ def __init__( segmentation_path=None, step_size=10, subgraph_radius=100, - use_new_mask=False + use_new_mask=False, ): # Call parent class super().__init__( @@ -445,7 +447,7 @@ def __init__( prefetch=prefetch, segmentation_path=segmentation_path, subgraph_radius=subgraph_radius, - use_new_mask=use_new_mask + use_new_mask=use_new_mask, ) # Instance attributes @@ -537,7 +539,13 @@ def _get_batch(self, nodes, img, offset): patch_centers = self.get_patch_centers(nodes) - offset # Populate batch array - batch = np.empty((len(nodes), 2,) + self.patch_shape) + batch = np.empty( + ( + len(nodes), + 2, + ) + + self.patch_shape + ) for i, center in enumerate(patch_centers): s = img_util.get_slices(center, self.patch_shape) batch[i, 0, ...] = img_util.normalize(img[s]) @@ -550,7 +558,13 @@ def _get_multimodal_batch(self, nodes, img, offset): patch_centers = self.get_patch_centers(nodes) - offset # Populate batch array - patches = np.empty((len(nodes), 2,) + self.patch_shape) + patches = np.empty( + ( + len(nodes), + 2, + ) + + self.patch_shape + ) point_clouds = np.empty((len(nodes), 3, 3600), dtype=np.float32) for i, (node, center) in enumerate(zip(nodes, patch_centers)): s = img_util.get_slices(center, self.patch_shape) @@ -610,7 +624,7 @@ def __init__( prefetch=128, segmentation_path=None, subgraph_radius=100, - use_new_mask=False + use_new_mask=False, ): # Call parent class super().__init__( @@ -623,7 +637,7 @@ def __init__( prefetch=prefetch, segmentation_path=segmentation_path, subgraph_radius=subgraph_radius, - use_new_mask=use_new_mask + use_new_mask=use_new_mask, ) # Instance attributes diff --git a/src/neuron_proofreader/proposal_graph.py b/src/neuron_proofreader/proposal_graph.py index 92b624fd..37382210 100644 --- a/src/neuron_proofreader/proposal_graph.py +++ b/src/neuron_proofreader/proposal_graph.py @@ -85,7 +85,8 @@ def __init__( self.reset_proposals() self.proposal_generator = ProposalGenerator( - self, max_proposals_per_leaf=max_proposals_per_leaf, + self, + max_proposals_per_leaf=max_proposals_per_leaf, ) # Graph Loader @@ -138,7 +139,7 @@ def generate_proposals( self, search_radius, allow_nonleaf_proposals=False, - min_size_with_proposals=0 + min_size_with_proposals=0, ): """ Generates proposals from leaf nodes. diff --git a/src/neuron_proofreader/skeleton_graph.py b/src/neuron_proofreader/skeleton_graph.py index 32fc7b4d..eb75f46d 100644 --- a/src/neuron_proofreader/skeleton_graph.py +++ b/src/neuron_proofreader/skeleton_graph.py @@ -330,7 +330,7 @@ def remove_soma_merges(self): self.relabel_nodes() results = [ f"# Soma Fragments: {len(self.soma_centroids)}", - f"# Soma Merges: {n_soma_merges}" + f"# Soma Merges: {n_soma_merges}", ] return "\n".join(results) diff --git a/src/neuron_proofreader/split_proofreading/proposal_generation.py b/src/neuron_proofreader/split_proofreading/proposal_generation.py index 4f406673..729597cb 100644 --- a/src/neuron_proofreader/split_proofreading/proposal_generation.py +++ b/src/neuron_proofreader/split_proofreading/proposal_generation.py @@ -56,7 +56,7 @@ def __call__( self, initial_radius, allow_nonleaf_proposals=False, - min_size_with_proposals=0 + min_size_with_proposals=0, ): """ Generates edge proposals between fragments within the given search diff --git a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py index cef95a29..63bab677 100644 --- a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py +++ b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py @@ -62,7 +62,7 @@ def __init__( brightness_clip=brightness_clip, patch_shape=patch_shape, padding=padding, - ), + ), ] def __call__(self, subgraph): @@ -329,7 +329,6 @@ def create_segment_mask(self, proposal, shape, offset): img_util.annotate_voxels(mask, voxels, val=0.25) visited.add(frozenset({i, j})) return mask - def read_image(self, center, shape): """ diff --git a/src/neuron_proofreader/split_proofreading/split_inference.py b/src/neuron_proofreader/split_proofreading/split_inference.py index b3bfcc7a..ed84f02b 100644 --- a/src/neuron_proofreader/split_proofreading/split_inference.py +++ b/src/neuron_proofreader/split_proofreading/split_inference.py @@ -175,8 +175,12 @@ def __call__( while self.dataset.proposals: # Generate predictons cnt += 1 - self.log(f"\nThreshold={new_threshold} w/ only_leaf2leaf={only_leaf2leaf}") - preds = self.predict_proposals(suffix=f"{name}_round={cnt}_threshold={new_threshold}") + self.log( + f"\nThreshold={new_threshold} w/ only_leaf2leaf={only_leaf2leaf}" + ) + preds = self.predict_proposals( + suffix=f"{name}_round={cnt}_threshold={new_threshold}" + ) # Merge accetped proposals cur_threshold = new_threshold diff --git a/src/neuron_proofreader/utils/graph_util.py b/src/neuron_proofreader/utils/graph_util.py index f01c0d50..fb77eda2 100644 --- a/src/neuron_proofreader/utils/graph_util.py +++ b/src/neuron_proofreader/utils/graph_util.py @@ -80,7 +80,9 @@ def __init__( self.node_spacing = node_spacing self.prefetch = prefetch self.prune_depth = prune_depth - self.swc_reader = swc_util.Reader(anisotropy, min_cable_length, verbose) + self.swc_reader = swc_util.Reader( + anisotropy, min_cable_length, verbose + ) self.verbose = verbose def __call__(self, swc_pointer): @@ -146,7 +148,7 @@ def load(self, swc_dict): subgraph. """ # Build graph - graph = swc_util.to_graph(swc_dict, set_attrs=True) + graph = swc_util.to_graph(swc_dict) prune_branches(graph, self.prune_depth) # Extract irreducible components (if applicable) diff --git a/src/neuron_proofreader/utils/swc_util.py b/src/neuron_proofreader/utils/swc_util.py index 419ecaab..903c784c 100644 --- a/src/neuron_proofreader/utils/swc_util.py +++ b/src/neuron_proofreader/utils/swc_util.py @@ -1,5 +1,5 @@ """ -Created on Wed June 5 16:00:00 2023 +Created on Thu May 21 12:00:00 2026 @author: Anna Grim @email: anna.grim@alleninstitute.org @@ -14,15 +14,17 @@ "z" (float): z coordinate "pid" (int): node ID of parent -Note: Each uncommented line in an SWC file corresponds to a node and contains - these attributes in the same order. +Note: Each line in an SWC file corresponds to a node and contains these + attributes in the same order. """ +from botocore import UNSIGNED +from botocore.client import Config from collections import deque from concurrent.futures import ( + as_completed, ProcessPoolExecutor, ThreadPoolExecutor, - as_completed, ) from google.auth.exceptions import RefreshError, TransportError from google.cloud import storage @@ -31,6 +33,7 @@ from zipfile import ZipFile import ast +import boto3 import networkx as nx import numpy as np import os @@ -41,10 +44,10 @@ class Reader: """ Class that reads SWC files stored in a (1) local directory, (2) local ZIP - archive, or (3) GCS directory, (4) GCS directory of ZIP archives. + archive, and (3) local directory of ZIP archives. """ - def __init__(self, anisotropy=(1.0, 1.0, 1.0), min_size=0, verbose=True): + def __init__(self, anisotropy=(1.0, 1.0, 1.0), verbose=True): """ Initializes a Reader object that reads SWC files. @@ -52,149 +55,138 @@ def __init__(self, anisotropy=(1.0, 1.0, 1.0), min_size=0, verbose=True): ---------- anisotropy : Tuple[float], optional Image to physical coordinates scaling factors to account for the - anisotropy of the microscope. Default is [1.0, 1.0, 1.0]. - min_size : int, optional - Threshold on the number nodes in SWC files that are parsed and - returned. Default is 0. - verbose : bool - Indication of whether to display a progress bar during loading. - Default is True. + anisotropy of the microscope. Default is (1.0, 1.0, 1.0). + verbose : bool, optional + Indication of whether to display a progress bar. Default is True. """ self.anisotropy = anisotropy - self.min_size = min_size self.verbose = verbose + # --- Read Data --- def __call__(self, swc_pointer): """ - Reads SWC files located at the path specified by "swc_pointer". + Loads SWC files based on the type pointer provided. Parameters ---------- - swc_pointer : str or List[str] + swc_pointer : str Object that points to SWC files to be read, must be one of: - file_path: Path to single SWC file - dir_path: Path to local directory with SWC files - zip_path: Path to local ZIP with SWC files - zip_dir_path: Path to local directory of ZIPs with SWC files + - s3_dir_path: Path to S3 prefix with SWC files - gcs_dir_path: Path to GCS prefix with SWC files - gcs_zip_dir_path: Path to GCS prefix with ZIPs of SWC files - - path_list: List of paths to local SWC files Returns ------- Deque[dict] - List of dictionaries whose keys and values are the attribute names - and values from the SWC files. Each dictionary contains the - following items: + Dictionaries whose keys and values are the attribute names and + values from the SWC files. Each dictionary contains the following: + items: - "id": unique identifier of each node in an SWC file. - "pid": parent ID of each node. - "radius": radius value corresponding to each node. - "xyz": coordinate corresponding to each node. - - "soma_nodes": nodes with soma type. - - "swc_name": name of SWC file, minus the ".swc". + - "filename": filename of SWC file + - "swc_id": name of SWC file, minus the ".swc". """ - # List of paths to SWC files - if isinstance(swc_pointer, list): - return self.read_from_paths(swc_pointer) - # Directory containing... if os.path.isdir(swc_pointer): - # ZIP archives with SWC files + # Local ZIP archives with SWC files paths = util.list_paths(swc_pointer, extension=".zip") if len(paths) > 0: - return self.read_from_zips(swc_pointer) + return self.read_zips(swc_pointer, self.read_zip) - # SWC files - paths = util.list_paths(swc_pointer, extension=".swc") + # Local SWC files + paths = util.read_paths(swc_pointer, extension=".swc") if len(paths) > 0: - return self.read_from_paths(paths) + return self.read_swcs(paths, self.read_swc) - raise Exception(f"Directory is invalid - {swc_pointer}") + raise Exception("Directory is Invalid!") # Path to... if isinstance(swc_pointer, str): - # Single SWC file in GCS - if util.is_gcs_path(swc_pointer) and swc_pointer.endswith(".swc"): - bucket_name, path = util.parse_cloud_path(swc_pointer) - return [self.read_from_gcs_swc(bucket_name, path)] - - # GCS directory - if util.is_gcs_path(swc_pointer): - return self.read_from_gcs(swc_pointer) + # Cloud GCS/S3 storage + if util.is_gcs_path(swc_pointer) or util.is_s3_path(swc_pointer): + return self.read_from_cloud(swc_pointer) - # ZIP archive with SWC files + # Local ZIP archive with SWC files if swc_pointer.endswith(".zip"): - return self.read_from_zip(swc_pointer) + return self.read_zip(swc_pointer) - # Single SWC file + # Local path to single SWC file if swc_pointer.endswith(".swc"): - return self.read_from_path(swc_pointer) + return self.read_swc(swc_pointer) - raise Exception(f"Path is invalid {swc_pointer}") + raise Exception("Path is Invalid!") - raise Exception(f"SWC Pointer is invalid {swc_pointer}") + raise Exception("SWC Pointer is Invalid!") - # --- Read subroutines --- - def read_from_paths(self, swc_paths): + def read_swc(self, path): """ - Reads a list of SWC files stored on the local machine. + Reads a single SWC file. Paramters --------- - swc_paths : List[str] - Paths to SWC files stored on the local machine. + path : str + Path to SWC file. Returns ------- - swc_dicts : Dequeue[dict] - List of dictionaries whose keys and values are the attribute - names and values from an SWC file. + dict + Dictionary whose keys and values are the attribute names and + values from an SWC file. """ - with ProcessPoolExecutor() as executor: - # Assign processes - processes = list() - for path in swc_paths: - processes.append(executor.submit(self.read_from_path, path)) + content = util.read_txt(path).splitlines() + filename = os.path.basename(path) + return self.parse(content, filename) - # Store results - swc_dicts = deque() - for process in as_completed(processes): - result = process.result() - if result: - swc_dicts.append(result) - return swc_dicts - - def read_from_path(self, path): + def read_swcs(self, swc_paths, read_fn): """ - Reads a single SWC file stored on the local machine. + Reads SWC files stored in a GCS or S3 bucket. - Paramters - --------- - path : str - Path to SWC file stored on the local machine. + Parameters + ---------- + bucket_name : str + Name of bucket containing SWC files. + swc_paths : List[str] + List of paths to SWC files to be read. Returns ------- - swc_dict : dict + swc_dicts : Deque[dict] Dictionaries whose keys and values are the attribute names and values from an SWC file. """ - content = util.read_txt(path) - if len(content) > self.min_size - 10: - swc_dict = self.parse(content) - swc_dict["swc_name"] = get_swc_name(path) - return swc_dict - else: - return False + with ThreadPoolExecutor() as executor: + # Assign threads + threads = set() + for path in swc_paths: + threads.add(executor.submit(read_fn, path)) - def read_from_zips(self, zip_dir): + # Store results + swc_dicts = deque() + pbar = self.manual_progress_bar(len(threads)) + for thread in as_completed(threads): + result = thread.result() + if result: + swc_dicts.append(result) + if self.verbose: + pbar.update(1) + return swc_dicts + + def read_zips(self, zip_paths, read_fn): """ - Reads a directory containing ZIP archives with SWC files. + Reads SWC files stored in ZIP archives. Parameters ---------- - zip_dir : str - Path to directory containing ZIP archives with SWC files. + bucket_name : str + Name of bucket containing SWC files. + zip_paths : List[str] + Paths to ZIP archives containing SWC files to be read. Returns ------- @@ -202,273 +194,251 @@ def read_from_zips(self, zip_dir): Dictionaries whose keys and values are the attribute names and values from an SWC file. """ - # Initializations - zip_names = [f for f in os.listdir(zip_dir) if f.endswith(".zip")] - iterator = zip_names - if self.verbose: - pbar = tqdm(iterator, desc="Read SWCs") - - # Main + pbar = self.manual_progress_bar(len(zip_paths)) with ProcessPoolExecutor() as executor: - # Assign threads - processes = list() - for f in iterator: - zip_path = os.path.join(zip_dir, f) - processes.append(executor.submit(self.read_from_zip, zip_path)) + # Assign processes + futures = {executor.submit(read_fn, path) for path in zip_paths} # Store results swc_dicts = deque() - for process in as_completed(processes): - swc_dicts.extend(process.result()) + for process in as_completed(futures): + try: + swc_dicts.extend(process.result()) + except RefreshError: + pass + if self.verbose: pbar.update(1) return swc_dicts - def read_from_zip(self, zip_path): + def read_zip(self, zip_path): """ - Reads SWC files from a ZIP archive stored on the local machine. + Reads SWC files from a ZIP archive. Paramters --------- - str : str - Path to a ZIP archive on the local machine. + zip_path : str + Path to ZIP archive. Returns ------- - swc_dicts : Dequeue[dict] - List of dictionaries whose keys and values are the attribute - names and values from an SWC file. + swc_dicts : Deque[dict] + Dictionaries whose keys and values are the attribute names and + values from an SWC file. """ with ThreadPoolExecutor() as executor: - with ZipFile(zip_path, "r") as zf: - # Submit threads - threads = list() - for f in [f for f in zf.namelist() if f.endswith(".swc")]: - threads.append( - executor.submit(self.read_from_zipped_file, zf, f) - ) + # Assign threads + threads = set() + zf = ZipFile(zip_path, "r") + for name in [f for f in zf.namelist() if f.endswith(".swc")]: + threads.add(executor.submit(self.read_zipped_swc, zf, name)) - # Store results - swc_dicts = deque() - for thread in as_completed(threads): - swc_dict = thread.result() - if swc_dict: - swc_dicts.append(swc_dict) + # Store results + swc_dicts = deque() + for thread in as_completed(threads): + result = thread.result() + if result: + swc_dicts.append(result) return swc_dicts - def read_from_zipped_file(self, zip_file, path): + def read_zipped_swc(self, zipfile, path): """ - Reads SWC file stored in a ZIP archive. + Reads an SWC file stored in a ZIP archive. Parameters ---------- - zip_file : ZipFile - ZIP archive containing SWC file to be read. + zipfile : ZipFile + ZIP archive containing SWC files. path : str - Path to SWC file to be read. + Path to SWC file. Returns ------- - swc_dict : dict - Dictionaries whose keys and values are the attribute names and + dict + Dictionary whose keys and values are the attribute names and values from an SWC file. """ - content = util.read_zip(zip_file, path).splitlines() - if len(content) > self.min_size - 10: - swc_dict = self.parse(content) - swc_dict["swc_name"] = get_swc_name(path) - return swc_dict - else: - return False + content = util.read_zip(zipfile, path).splitlines() + filename = os.path.basename(path) + return self.parse(content, filename) - def read_from_gcs(self, gcs_path): + def read_from_cloud(self, path): """ - Reads SWC files stored in a GCS bucket. + Reads SWC files stored in a GCS or S3 bucket. Parameters ---------- - gcs_path : str - Path to SWC files located in a GCS bucket. + path : str + Path to location in a GCS or S3 bucket containing SWC files, + must be in the format "{scheme}://{bucket_name}/{prefix}". Returns ------- - Dequeue[dict] - List of dictionaries whose keys and values are the attribute - names and values from an SWC file. + Deque[dict] + Dictionaries whose keys and values are the attribute names and + values from an SWC file. """ - # List filenames - bucket_name, prefix = util.parse_cloud_path(gcs_path) - swc_paths = util.list_gcs_filenames(bucket_name, prefix, ".swc") - zip_paths = util.list_gcs_filenames(bucket_name, prefix, ".zip") + # Extact info + assert util.is_s3_path(path) or util.is_gcs_path(path) + use_s3 = util.is_s3_path(path) + + # List paths + swc_paths = util.list_cloud_paths(path, ".swc") + zip_paths = util.list_cloud_paths(path, ".zip") # Call reader - if len(swc_paths) > 0: - return self.read_from_gcs_swcs(bucket_name, swc_paths) - if len(zip_paths) > 0: - return self.read_from_gcs_zips(bucket_name, zip_paths) + if swc_paths: + return self.read_swcs(swc_paths, self.read_swc) + elif zip_paths: + read_fn = self.read_s3_zip if use_s3 else self.read_gcs_zip + return self.read_zips(zip_paths, read_fn) - # Error - raise Exception(f"GCS Pointer is invalid {gcs_path}") + raise Exception(f"SWC Pointer is invalid {path}") - def read_from_gcs_swcs(self, bucket_name, swc_paths): + def read_gcs_swc(self, path): """ - Reads SWC files stored in a GCS bucket. + Reads a single SWC file stored in a GCS bucket. Parameters ---------- - bucket_name : str - Name of GCS bucket containing SWC files. - swc_paths : List[str] - Paths to SWC files. + path : List[str] + Path to SWC file to be read. Returns ------- - swc_dicts : Dequeue[dict] - List of dictionaries whose keys and values are the attribute - names and values from an SWC file. + Deque[dict] + Dictionaries whose keys and values are the attribute names and + values from an SWC file. """ - if self.verbose: - pbar = tqdm(total=len(swc_paths), desc="Read SWCs") - - with ThreadPoolExecutor() as executor: - # Assign threads - threads = list() - for path in swc_paths: - threads.append( - executor.submit(self.read_from_gcs_swc, bucket_name, path) - ) + # Initialize cloud reader + bucket_name, subpath = util.parse_cloud_path(path) + bucket = storage.Client().bucket(bucket_name) + blob = bucket.blob(subpath) - # Store results - swc_dicts = deque() - for thread in as_completed(threads): - result = thread.result() - if result: - swc_dicts.append(result) - if self.verbose: - pbar.update(1) - return swc_dicts + # Parse swc contents + content = blob.download_as_text().splitlines() + filename = os.path.basename(subpath) + return self.parse(content, filename) - def read_from_gcs_swc(self, bucket_name, path): + def read_gcs_zip(self, path): """ - Reads a single SWC file stored in a GCS bucket. + Reads SWC files stored in a ZIP archive downloaded from a GCS + bucket. Parameters ---------- - bucket_name : str - Name of GCS bucket containing SWC files. - swc_path : str - Path to SWC file to be read. + path : str + Path to ZIP archive containing SWC files to be read. Returns ------- - swc_dict : dict + swc_dicts : Deque[dict] Dictionaries whose keys and values are the attribute names and values from an SWC file. """ - # Initialize cloud reader - client = storage.Client() - bucket = client.bucket(bucket_name) - blob = bucket.blob(path) + # Download ZIP + bucket_name, path = util.parse_cloud_path(path) + bucket = storage.Client().bucket(bucket_name) + try: + zip_content = bucket.blob(path).download_as_bytes() + except TransportError: + print(f"Failed to read {path}!") + return deque() - # Parse swc contents - content = blob.download_as_text().splitlines() - if len(content) > self.min_size - 10: - swc_dict = self.parse(content) - swc_dict["swc_name"] = get_swc_name(path) - return swc_dict - else: - return False + # Parse ZIP contents + swc_dicts = deque() + with ZipFile(BytesIO(zip_content), "r") as zf: + with ThreadPoolExecutor() as executor: + # Assign threads + threads = set() + for name in zf.namelist(): + if self.confirm_read(name): + threads.add( + executor.submit(self.read_zipped_swc, zf, name) + ) + + # Process results + for thread in as_completed(threads): + result = thread.result() + if result: + swc_dicts.append(result) + return swc_dicts - def read_from_gcs_zips(self, bucket_name, zip_paths): + def read_s3_zip(self, path): """ - Reads SWC files from ZIP archives stored in a GCS bucket. + Reads SWC files stored in a ZIP archive downloaded from an S3 + bucket. Parameters ---------- - bucket_name : str - Name of GCS bucket containing SWC files. - zip_paths : List[str] - Paths to ZIP archives in a GCS bucket. + path : str + Path to ZIP archive containing SWC files to be read. Returns ------- - swc_dicts : Dequeue[dict] - List of dictionaries whose keys and values are the attribute - names and values from an SWC file. + swc_dicts : Deque[dict] + Dictionaries whose keys and values are the attribute names and + values from an SWC file. """ - # Initializations - batch_size = 1000 - if self.verbose: - pbar = tqdm(total=len(zip_paths), desc="Read SWCs") - - # Main - swc_dicts = deque() - with ProcessPoolExecutor() as executor: - for i in range(0, len(zip_paths), batch_size): - # Assign processes - processes = list() - for zip_path in zip_paths[i: i + batch_size]: - processes.append( - executor.submit( - self.read_from_gcs_zip, bucket_name, zip_path - ) + # Initialize cloud reader + bucket, key = util.parse_cloud_path(path) + s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED)) + zip_content = s3.get_object(Bucket=bucket, Key=key)["Body"].read() + + # Parse ZIP + with ZipFile(BytesIO(zip_content), "r") as zf: + with ThreadPoolExecutor() as executor: + # Assign threads + threads = set() + for name in zf.namelist(): + threads.add( + executor.submit(self.read_zipped_swc, zf, name) ) # Store results - for process in as_completed(processes): - try: - swc_dicts.extend(process.result()) - except RefreshError: - pass - if self.verbose: - pbar.update(1) + swc_dicts = deque() + for thread in as_completed(threads): + result = thread.result() + if result: + swc_dicts.append(result) return swc_dicts - def read_from_gcs_zip(self, bucket_name, zip_path, filenames=None): + # -- Process Text --- + def iterator(self, iterator): """ - Reads SWC files stored in a ZIP archive downloaded from a cloud - bucket. + Gets an iterator that optionally displays a progress bar. Parameters ---------- - bucket_name : str - Name of GCS bucket containing SWC files. - zip_path : str - Path to ZIP archive to be read. - filenames : None or List[str], optional - Filenames to be read if provided. Default is None. + iterator : iterable + Object to be iterated over. Returns ------- - swc_dicts : Dequeue[dict] - List of dictionaries whose keys and values are the attribute - names and values from an SWC file. + tqdm.tqdm + Iterator that is optionally wrapped in a progress bar. """ - try: - # Download zip - client = storage.Client() - bucket = client.bucket(bucket_name) - zip_content = bucket.blob(zip_path).download_as_bytes() - except TransportError: - print(f"Failed to read {zip_path}!") - return deque() + return tqdm(iterator, desc="Read SWCs") if self.verbose else iterator - # Process files - swc_dicts = deque() - with ZipFile(BytesIO(zip_content), "r") as zip_file: - filenames = zip_file.namelist() if filenames is None else filenames - for filename in filenames: - result = self.read_from_zipped_file(zip_file, filename) - if result: - swc_dicts.append(result) - return swc_dicts + def manual_progress_bar(self, total): + """ + Gets progress bar that needs to be updated manually. + + Parameters + ---------- + total : int + Size of progress bar. + + Returns + ------- + tqdm.tqdm + Iterator that is optionally wrapped in a progress bar. + """ + return tqdm(total=total, desc="Read SWCs") if self.verbose else None - # --- Process content --- - def parse(self, content): + def parse(self, content, filename): """ Parses an SWC file to extract the content which is stored in a dict. - Note that node_ids from SWC are reindex from 0 to n-1 where n is the - number of nodes in the SWC file. Parameters ---------- @@ -477,34 +447,36 @@ def parse(self, content): Returns ------- - swc_dict : dict - Dictionaries whose keys and values are the attribute names - and values from an SWC file. + dict + Dictionary whose keys and values are the attribute names and + values from an SWC file. """ # Initializations + swc_name, _ = os.path.splitext(filename) content, offset = self.process_content(content) - swc_dict = { - "id": np.zeros((len(content)), dtype=int), - "radius": np.zeros((len(content)), dtype=np.float16), - "pid": np.zeros((len(content)), dtype=int), - "xyz": np.zeros((len(content), 3), dtype=np.float32), - "soma_nodes": set(), - } - - # Parse content - for i, line in enumerate(content): - parts = line.split() - swc_dict["id"][i] = parts[0] - swc_dict["radius"][i] = float(parts[-2]) - swc_dict["pid"][i] = parts[-1] - swc_dict["xyz"][i] = self.read_xyz(parts[2:5], offset) - if int(parts[1]) == 1: - swc_dict["soma_nodes"].add(parts[0]) - - # Convert radius from nanometers to microns - if swc_dict["radius"][0] > 100: - swc_dict["radius"] /= 1000 - return swc_dict + if len(content) > 0: + swc_dict = { + "id": np.zeros((len(content)), dtype=int), + "pid": np.zeros((len(content)), dtype=int), + "radius": np.zeros((len(content)), dtype=float), + "xyz": np.zeros((len(content), 3), dtype=np.int32), + "swc_name": swc_name, + } + + # Parse content + for i, line in enumerate(content): + parts = line.split() + swc_dict["id"][i] = parts[0] + swc_dict["pid"][i] = parts[-1] + swc_dict["radius"][i] = float(parts[-2]) + swc_dict["xyz"][i] = self.read_coordinate(parts[2:5], offset) + + # Convert radius from nanometers to microns + if swc_dict["radius"][0] > 100: + swc_dict["radius"] /= 1000 + return swc_dict + else: + return None def process_content(self, content): """ @@ -520,33 +492,33 @@ def process_content(self, content): Returns ------- content : List[str] - List of strings representing the lines of text starting from the - line immediately after the last commented line. - offset : List[float] - Offset used to shift coordinates. + Lines from an SWC file after comments. + offset : Tuple[int] + Offset used to shift coordinate. """ offset = (0, 0, 0) for i, line in enumerate(content): if line.startswith("# OFFSET"): - offset = self.read_xyz(line.split()[2:5]) - if not line.startswith("#") and len(line) > 0: + parts = line.split() + offset = self.read_coordinate(parts[2:5]) + if not line.startswith("#") and len(line.strip()) > 0: return content[i:], offset - def read_xyz(self, xyz_str, offset=(0, 0, 0)): + def read_coordinate(self, xyz_str, offset=(0, 0, 0)): """ - Reads a 3D coordinate from a string and transforms it. + Reads a coordinate from a string and converts it to voxel coordinates. Parameters ---------- xyz_str : str - Coordinate stored as a str. - offset : List[float], optional - Shift applied to coordinate. Default is (0, 0, 0). + Coordinate stored as a string. + offset : Tuple[int] + Offset of coordinates in SWC file. Default is (0, 0, 0). Returns ------- - List[float] - Coordinate of node from an SWC file. + Tuple[int] + xyz coordinates of an entry from an SWC file. """ iterator = zip(self.anisotropy, xyz_str, offset) return [a * (float(s) + o) for a, s, o in iterator] @@ -575,10 +547,10 @@ def write_points( radius : float, optional Radius to be used in SWC file. Default is 10. """ - zip_writer = ZipFile(zip_path, write_mode) + zf = ZipFile(zip_path, write_mode) for i, xyz in enumerate(points): filename = prefix + str(i + 1) + ".swc" - to_zipped_point(zip_writer, filename, xyz, color=color, radius=radius) + to_zipped_point(zf, filename, xyz, color=color, radius=radius) def to_zipped_point(zip_writer, filename, xyz, color=None, radius=5): @@ -654,17 +626,14 @@ def get_swc_name(path): return name -def to_graph(swc_dict, set_attrs=False): +def to_graph(swc_dict): """ - Converts an SWC dict to a NetworkX graph with reindexed nodes. + Converts an SWC dictionary to a NetworkX graph with reindexed nodes. Parameters ---------- swc_dict : dict Contents of an SWC file. - set_attrs : bool, optional - Indication of whether to set "xyz" and "radius" as graph-level - attributes. Default is False. Returns ------- @@ -679,9 +648,10 @@ def to_graph(swc_dict, set_attrs=False): ] # Build graph with reindexed edges - graph = nx.Graph(swc_name=swc_dict["swc_name"]) + graph = nx.Graph( + swc_name=swc_dict["swc_name"], + radius=swc_dict["radius"], + xyz=swc_dict["xyz"], + ) graph.add_edges_from(edges) - if set_attrs: - graph.graph["xyz"] = swc_dict["xyz"] - graph.graph["radius"] = swc_dict["radius"] return graph diff --git a/src/neuron_proofreader/utils/util.py b/src/neuron_proofreader/utils/util.py index d0abe0ce..180cac62 100644 --- a/src/neuron_proofreader/utils/util.py +++ b/src/neuron_proofreader/utils/util.py @@ -223,20 +223,25 @@ def read_json(path): def read_txt(path): """ - Reads txt file located at the given path. + Reads txt file at the given path. Parameters ---------- path : str - Path to txt file to be read. + Path to txt file. Returns ------- str - Contents of txt file. + Text from the txt file. """ - with open(path, "r") as f: - return f.read().splitlines() + if is_s3_path(path): + return read_txt_from_s3(path) + elif is_gcs_path(path): + return read_txt_from_gcs(path) + else: + with open(path, "r") as f: + return f.read() def read_zip(zip_file, path): @@ -328,6 +333,58 @@ def write_txt(path, contents): f.close() +# --- Cloud Utils --- +def list_cloud_paths(path, extension=""): + """ + Lists all files in a GCS/S3 bucket with the given extension. + + Parameters + ---------- + path : str + Path to cloud prefix to be searched, must be in the format: + f"{scheme}://{bucket_name}/{prefix}". + extension : str, optional + File extension of filenames to be listed. Default is an empty string. + + Returns + ------- + List[str] + Filenames stored at the GCS path with the given extension. + """ + assert is_gcs_path(path) or is_s3_path(path) + bucket_name, prefix = parse_cloud_path(path) + list_fn = list_gcs_paths if is_gcs_path(path) else list_s3_paths + return list_fn(bucket_name, prefix, extension=extension) + + +def parse_cloud_path(path): + """ + Parses a cloud storage path into its bucket name and key/prefix. Supports + paths of the form: "{scheme}://bucket_name/prefix" or without a scheme. + + Parameters + ---------- + path : str + Path to be parsed. + + Returns + ------- + bucket_name : str + Name of the bucket. + prefix : str + Cloud prefix. + """ + # Remove s3:// or gs:// if present + if path.startswith("s3://") or path.startswith("gs://"): + path = path[len("s3://"):] + + # Split path + parts = path.split("/", 1) + bucket_name = parts[0] + prefix = parts[1] if len(parts) > 1 else "" + return bucket_name, prefix + + # --- GCS utils --- def check_gcs_file_exists(bucket_name, path): """ @@ -368,14 +425,14 @@ def is_gcs_path(path): return path.startswith("gs://") -def list_gcs_filenames(bucket_name, prefix, extension=""): +def list_gcs_paths(bucket_name, prefix, extension=""): """ - Lists all files in a GCS bucket with the given extension. + Lists paths at a GCS prefix with the given extension. Parameters ---------- bucket_name : str - Name of bucket to be searched. + Name of bucket containing prefix. prefix : str Path to location within bucket to be searched. extension : str, optional @@ -384,11 +441,14 @@ def list_gcs_filenames(bucket_name, prefix, extension=""): Returns ------- List[str] - Filenames stored at the GCS path with the given extension. + Paths under the GCS prefix with the given extension. """ bucket = storage.Client().bucket(bucket_name) - blobs = bucket.list_blobs(prefix=prefix) - return [blob.name for blob in blobs if extension in blob.name] + paths = list() + for name in [b.name for b in bucket.list_blobs(prefix=prefix)]: + if extension in name: + paths.append(os.path.join(f"gs://{bucket_name}", name)) + return paths def list_gcs_subdirectories(bucket_name, prefix): @@ -447,6 +507,25 @@ def read_json_from_gcs(bucket_name, blob_path): return json.loads(blob.download_as_text()) +def read_txt_from_gcs(path): + """ + Reads a txt file stored in a GCS bucket. + + Parameters + ---------- + path : str + Path to txt file to be read. + + Returns + ------- + str + Contents of txt file. + """ + bucket_name, subpath = parse_cloud_path(path) + bucket = storage.Client().bucket(bucket_name) + return bucket.blob(subpath).download_as_text() + + # --- S3 utils --- def is_s3_path(path): """ @@ -465,35 +544,58 @@ def is_s3_path(path): return path.startswith("s3://") -def list_s3_prefixes(bucket_name, prefix): +def list_s3_paths(bucket_name, prefix, extension=""): """ - Lists all immediate subdirectories of a given S3 path (prefix). + Lists all object keys in a public S3 bucket under a given prefix, + optionally filters by file extension. Parameters - ----------- + ---------- bucket_name : str - Name of the S3 bucket to search. + Name of the S3 bucket. prefix : str - S3 prefix to search within. + Prefix to search under. + extension : str, optional + File extension to filter by. Default is an empty string. - Returns: - -------- - List[str] - Immediate subdirectories under the specified prefix. + Returns + ------- + paths : List[str] + S3 object keys that match the prefix and extension filter. """ - # Check prefix is valid - if not prefix.endswith("/"): - prefix += "/" + # Create an anonymous client for public buckets + s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED)) + response = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix) + + # List all objects under the prefix + paths = list() + if "Contents" in response: + for obj in response["Contents"]: + filename = obj["Key"] + if filename.endswith(extension): + path = os.path.join(f"s3://{bucket_name}", filename) + paths.append(path) + return paths + - # Call the list_objects_v2 API +def read_txt_from_s3(path): + """ + Reads a txt file stored in an S3 bucket. + + Parameters + ---------- + path : str + Path to txt file to be read. + + Returns + ------- + str + Contents of txt file. + """ + bucket_name, subpath = parse_cloud_path(path) s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED)) - response = s3.list_objects_v2( - Bucket=bucket_name, Prefix=prefix, Delimiter="/" - ) - if "CommonPrefixes" in response: - return [cp["Prefix"] for cp in response["CommonPrefixes"]] - else: - return list() + obj = s3.get_object(Bucket=bucket_name, Key=subpath) + return obj["Body"].read().decode("utf-8") def upload_dir_to_s3(dir_path, bucket_name, prefix): @@ -618,34 +720,6 @@ def numpy_to_hashable(arr): return [tuple(item) for item in arr.tolist()] -def parse_cloud_path(path): - """ - Parses a cloud storage path into its bucket name and key/prefix. Supports - paths of the form: "{scheme}://bucket_name/prefix" or without a scheme. - - Parameters - ---------- - path : str - Path to be parsed. - - Returns - ------- - bucket_name : str - Name of the bucket. - prefix : str - Cloud prefix. - """ - # Remove s3:// or gs:// if present - if path.startswith("s3://") or path.startswith("gs://"): - path = path[len("s3://"):] - - # Split path - parts = path.split("/", 1) - bucket_name = parts[0] - prefix = parts[1] if len(parts) > 1 else "" - return bucket_name, prefix - - def sample_once(my_container): """ Samples a single element from the given container. diff --git a/tests/__init__.py b/tests/__init__.py index d0a85479..024d51f2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,2 +1,3 @@ """Init package""" + __version__ = "0.0.0"