From aec65681b2fbe20555d62253f259afa7932cf846 Mon Sep 17 00:00:00 2001 From: Stephen Aylward Date: Wed, 20 May 2026 07:21:28 -0400 Subject: [PATCH 1/2] ENH: Initial implementation of a PhysicsNeMo tutorial --- .pre-commit-config.yaml | 4 +- docs/API_MAP.md | 346 ++++++++++++++++++ pyproject.toml | 3 + tutorials/__init__.py | 0 ...utorial_09_physicsnemo_mesh_stage_model.py | 266 ++++++++++++++ 5 files changed, 617 insertions(+), 2 deletions(-) delete mode 100644 tutorials/__init__.py create mode 100644 tutorials/tutorial_09_physicsnemo_mesh_stage_model.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 34bc0d3..539e1af 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,13 +19,13 @@ repos: - id: mypy # With pass_filenames: false, we must provide an explicit target for mypy. # Otherwise mypy exits with: "Missing target module, package, files, or command." - args: [--config-file=pyproject.toml, src] + args: [--config-file=pyproject.toml, src, tutorials] additional_dependencies: - types-setuptools - scipy-stubs - types-requests pass_filenames: false - files: ^src/ + files: ^(src/|tutorials/) # Regenerate docs/API_MAP.md whenever Python source files change - repo: local diff --git a/docs/API_MAP.md b/docs/API_MAP.md index b70fe02..2bdc798 100644 --- a/docs/API_MAP.md +++ b/docs/API_MAP.md @@ -26,6 +26,352 @@ _Re-run `py utils/generate_api_map.py` whenever public APIs change._ - `def generate_pc_variation(pc_index, std_dev_multiplier=3.0)` (line 155): Generate shape variations along a principal component. +## experiments/LongitudinalRegistration/0-cardiacGatedCT_segment_and_landmark.py + +- `def segment_images(src_data_dirs, src_data_files)` (line 57): Segment each image with SegmentHeartSimpleware and save labelmaps. + +## experiments/LongitudinalRegistration/1-finetune_icon.py + +- `def get_segmented_images(src_data_dirs, src_data_files)` (line 84): Segment each image with SegmentHeartSimpleware and save labelmaps. +- `def get_mask_images(src_data_dirs, src_data_files)` (line 117): Get mask images for each image. + +## experiments/LongitudinalRegistration/recon_4d_icon_finetuned.py + +- `def convert_labelmap_to_masks(labelmap_file, output_dir)` (line 82) +- `def register_time_series(reference_image_file, reference_labelmap_file, source_image_dir, source_image_files, segmented_image_files, weights_path)` (line 91) + +## experiments/LongitudinalRegistration/recon_4d_run.py + +- `def register_time_series(reference_image_file, source_image_dir, source_image_files, registration_method)` (line 69) + +## experiments/LongitudinalRegistration/uniGradICON/scripts/prepare_l2r_datasets.py + +- `def generate_oasis_json(data_dir, output_dir)` (line 31): Generate JSON for OASIS brain MRI (unpaired, with segmentations). +- `def generate_lungct_json(data_dir, output_dir)` (line 57): Generate JSON for LungCT (paired). +- `def generate_abdomenmrct_json(data_dir, output_dir)` (line 89): Generate JSON for AbdomenMRCT cross-modality dataset (unpaired, with segmentations). +- `def main()` (line 122) + +## experiments/LongitudinalRegistration/uniGradICON/src/unigradicon/__init__.py + +- **class GradientICONSparse** (line 23) + - `def __init__(self, network, similarity, lmbda, use_label=False, apply_intensity_conservation_loss=False, dice_loss_weight=0.0, loss_function_masking=False)` (line 24) + - `def forward(self, image_A, image_B, label_A=None, label_B=None, mask_A=None, mask_B=None, segmentation_A=None, segmentation_B=None)` (line 35) + - `def compute_jacobian_determinant(self, phi)` (line 254) + - `def dice_loss(self, pred, target, epsilon=1e-06)` (line 277): Compute Dice loss between one-hot encoded prediction and target. + - `def clean(self)` (line 298) +- `def make_network(input_shape, include_last_step=False, lmbda=1.5, loss_fn=icon.LNCC(sigma=5), use_label=False, apply_intensity_conservation_loss=False, dice_loss_weight=0.0, loss_function_masking=False)` (line 324) +- `def make_sim(similarity, sigma=5, mind_radius=2, mind_dilation=2)` (line 340) +- `def get_multigradicon(loss_fn=icon.LNCC(sigma=5), apply_intensity_conservation_loss=False, weights_location=None, dice_loss_weight=0.0, loss_function_masking=False)` (line 350) +- `def get_unigradicon(loss_fn=icon.LNCC(sigma=5), apply_intensity_conservation_loss=False, weights_location=None, dice_loss_weight=0.0, loss_function_masking=False)` (line 369) +- `def get_model_from_model_zoo(model_name='unigradicon', loss_fn=icon.LNCC(sigma=5), apply_intensity_conservation_loss=False, dice_loss_weight=0.0, loss_function_masking=False, weights_location=None)` (line 387) +- `def quantile(arr, q)` (line 407) +- `def apply_mask(image, mask)` (line 412) +- `def preprocess(image, modality='ct', mask=None, ct_window=None, quantile_range=None)` (line 427): Preprocess a medical image for registration. +- `def main()` (line 470) +- `def warp_command()` (line 669) +- `def maybe_cast(img)` (line 711): If an itk image is of a type that can't be used with InterpolateImageFunctions, cast it +- `def compute_jacobian_map_command()` (line 729) + +## experiments/LongitudinalRegistration/uniGradICON/src/unigradicon/finetuning/config.py + +- `def set_reproducibility_seed(seed)` (line 26): Seed Python, NumPy, and PyTorch (CPU and GPU) at startup. +- **class ConfigSections** (line 36) +- **class ExperimentKeys** (line 42) +- **class TrainingKeys** (line 47) +- **class DatasetKeys** (line 69) +- **class DatasetTypes** (line 84) +- **class JsonKeys** (line 89) +- **class ExperimentConfig** (line 134) + - `def from_dict(cls, raw)` (line 139) +- **class TrainingConfig** (line 147) + - `def from_dict(cls, raw)` (line 169) + - `def network_input_shape(self)` (line 178): ``input_shape`` with the [1, 1] batch+channel prefix that +- **class DatasetConfig** (line 185): ``name``/``type``/``json_file`` are YAML-required and validated in + - `def from_dict(cls, raw)` (line 205) +- **class FinetuningConfigSchema** (line 215) + - `def from_dict(cls, raw)` (line 221) +- **class DataLoaderBundle** (line 237) +- **class ConfigValidator** (line 244) + - `def __init__(self, config)` (line 245) + - `def validate(self)` (line 248) +- `def validate_config(config)` (line 345) +- `def load_config(config_path)` (line 349) +- **class DatasetJsonCache** (line 358): Memoizes JSON parsing + path resolution across the multiple validators + - `def __init__(self)` (line 362) + - `def load_content(self, json_path)` (line 366) + - `def load_entries(self, json_path)` (line 378) + - `def entry_field_sets(self, json_path)` (line 418) +- `def required_data_fields(training)` (line 428) +- `def determine_data_fields(dataset_configs, config_dir, json_cache=None)` (line 437) +- `def validate_paired_datasets_have_pairs(schema, config_dir, json_cache)` (line 452): Catches missing-pairs at config-validation time so we don't spend +- `def validate_training_data_compatibility(fields_per_dataset, data_fields)` (line 484) +- `def create_dataset_from_config(dataset_config, input_shape, config_dir='', use_label=False, data_fields=frozenset(), json_cache=None)` (line 519): Build a ``Dataset`` (random pairing) or ``PairedDataset`` (subject-based +- `def create_data_loaders(config_path, config=None)` (line 663): Create training and validation dataloaders from YAML config. + +## experiments/LongitudinalRegistration/uniGradICON/src/unigradicon/finetuning/dataset.py + +- **class Fields** (line 22) +- **class CacheNames** (line 30) +- **class PairKeys** (line 42) +- **class DatasetParams** (line 54): Shared defaults for ``Dataset.__init__`` (via ``_DEFAULTS``) and +- **class DatasetEntry** (line 92) + - `def from_dict(cls, item, idx, dataset_name)` (line 100) +- **class ImageReader** (line 129) + - `def read(self, path)` (line 130) +- **class ImagePreprocessor** (line 138) + - `def __init__(self, reader, input_shape, is_ct, ct_window, quantile_range, modality_map)` (line 139) + - `def preprocess_image(self, path)` (line 155) + - `def preprocess_label_map(self, path)` (line 178) +- **class DatasetCache** (line 185): Cache path is partitioned by a signature hashed from every + - `def __init__(self, dataset_name, cache_dir, enabled, signature)` (line 191) + - `def path(self, cache_name)` (line 214) + - `def load(self, cache_name)` (line 219) + - `def save(self, cache_name, payload)` (line 233) + - `def write_metadata(self, meta)` (line 244): Sidecar ``_meta.json`` describing the params behind the +- **class RandomPairSampler** (line 260) + - `def __init__(self, keys)` (line 261) + - `def sample_partner(self, anchor)` (line 267) +- **class SubjectPairSampler** (line 277): Pair only images sharing the same ``subject_id``. + - `def __init__(self, entries, keys, dataset_name)` (line 280) + - `def sample_partner(self, anchor)` (line 302) +- **class Dataset** (line 306): 3D medical-image registration dataset. + - `def __init__(self, input_shape, name, data, cache_dir=_DEFAULTS.cache_dir, maximum_images=_DEFAULTS.maximum_images, shuffle=_DEFAULTS.shuffle, is_ct=_DEFAULTS.is_ct, ct_window=_DEFAULTS.ct_window, quantile_range=_DEFAULTS.quantile_range, use_cache=_DEFAULTS.use_cache, use_compression=_DEFAULTS.use_compression, use_label=False)` (line 320) + - `def has_segmentation(self)` (line 539) + - `def has_mask(self)` (line 543) + - `def get_image(self, key)` (line 556) +- **class PairedDataset** (line 603): Variant of ``Dataset`` that pairs only within ``subject_id``. JSON + - `def __init__(self, **kwargs)` (line 608) + +## experiments/LongitudinalRegistration/uniGradICON/src/unigradicon/finetuning/finetune.py + +- `def loss_to_dict(loss_object)` (line 33) +- `def augment(batch)` (line 50): Apply random affine augmentation to all spatial data in a batch dict. +- `def finetune_multi(config, data_loader, val_data_loaders_dict, data_fields)` (line 304): Unified finetuning loop. +- `def main(argv=None)` (line 432) + +## experiments/LongitudinalRegistration/uniGradICON/src/unigradicon/finetuning/visualization.py + +- `def render_for_tensorboard(im, max_samples=MAX_DISPLAY_SAMPLES, normalize=True)` (line 42): Prepare image tensor for TensorBoard as an RGB image batch. +- `def segmentation_labels_for_tensorboard(im, max_samples=MAX_DISPLAY_SAMPLES)` (line 59): Convert one-hot or label-map segmentations to integer label images. +- `def labels_to_color_image(labels)` (line 108) +- `def render_segmentation_overlay_for_tensorboard(image, segmentation, alpha=DEFAULT_OVERLAY_ALPHA)` (line 116): Blend a rendered segmentation over its image for TensorBoard. +- `def render_mask_overlay_for_tensorboard(image, mask, alpha=DEFAULT_OVERLAY_ALPHA, color=DEFAULT_MASK_OVERLAY_COLOR)` (line 134): Blend a binary ROI mask over its image as a single solid color. +- `def add_eval_composite_panel(writer, step, moving, fixed, warped, moving_seg=None, fixed_seg=None, warped_seg=None, moving_mask=None, fixed_mask=None, tag='eval')` (line 162): Write a single composite eval panel under ``{tag}`` (default ``"eval"``). + +## experiments/LongitudinalRegistration/uniGradICON/src/unigradicon/unicarl.py + +- `def get_unicarl()` (line 24) +- `def main()` (line 41) + +## experiments/LongitudinalRegistration/uniGradICON/tests/finetuning/conftest.py + +- `def fake_image_reader(monkeypatch)` (line 9): Replace ITK-backed ``ImageReader`` with a deterministic fake so tests + +## experiments/LongitudinalRegistration/uniGradICON/tests/finetuning/test_cache.py + +- `def test_cache_disabled_returns_none_paths(tmp_path)` (line 25) +- `def test_cache_save_and_load_roundtrip(tmp_path)` (line 31) +- `def test_cache_atomic_write_no_partial_file_visible(tmp_path)` (line 41): The save path uses tmp + os.replace; tmp file should not linger. +- `def test_cache_load_returns_none_on_corrupt_file(tmp_path, caplog)` (line 51): A truncated/garbage cache file should not propagate; it should warn +- `def test_cache_write_metadata_skips_when_already_present(tmp_path)` (line 67) +- `def test_dataset_cache_signature_changes_with_quantile_range(tmp_path, fake_image_reader)` (line 84) +- `def test_dataset_cache_signature_stable_for_identical_params(tmp_path, fake_image_reader)` (line 97) +- `def test_dataset_cache_signature_changes_with_input_shape(tmp_path, fake_image_reader)` (line 108) +- `def test_dataset_skip_save_on_clean_cache_hit(tmp_path, fake_image_reader)` (line 119): Second construction with identical params must not rewrite cache files. +- `def test_dataset_rebuilds_when_cache_corrupt(tmp_path, fake_image_reader, caplog)` (line 144): A garbage cache file triggers warning + rebuild, not an exception. +- `def test_dataset_metadata_sidecar_describes_signature(tmp_path, fake_image_reader)` (line 163) +- `def test_dataset_segmentation_cache_written_under_signature_dir(tmp_path, fake_image_reader)` (line 180): Aux maps are persisted as their own ``.trch`` cache file beside images. +- `def test_dataset_segmentation_cache_round_trips(tmp_path, fake_image_reader, monkeypatch)` (line 191): A second construction with identical params loads aux maps from cache +- `def test_dataset_compresses_in_memory_when_enabled(tmp_path, fake_image_reader)` (line 212): With ``use_compression=True`` the in-memory store holds blosc bytes, +- `def test_dataset_skips_compression_by_default(tmp_path, fake_image_reader)` (line 223): The default ``use_compression=False`` keeps tensors uncompressed. +- `def test_dataset_compression_roundtrip_returns_equal_tensor(tmp_path, fake_image_reader)` (line 232): ``get_image`` must return tensors that are equal regardless of whether +- `def test_dataset_compression_signature_partitions_cache(tmp_path, fake_image_reader)` (line 247): Toggling ``use_compression`` produces a different cache signature so +- `def test_dataset_rebuilds_aux_cache_when_keys_outgrow_cache(tmp_path, fake_image_reader)` (line 261): If a previous run wrote an aux cache for a smaller key set (e.g. some +- `def test_dataset_indexing_anchor_is_deterministic(tmp_path, fake_image_reader)` (line 284): ds[i] anchor must be self.keys[i] regardless of RNG state. + +## experiments/LongitudinalRegistration/uniGradICON/tests/finetuning/test_cli.py + +- `def isolate_cwd(tmp_path, monkeypatch)` (line 20): Run inside a temp dir so footsteps' results/ dir doesn't pollute the +- `def mock_finetune_multi(monkeypatch)` (line 33): Capture the (config, data_loader, val_loaders, data_fields) call so the +- `def test_cli_requires_config_arg()` (line 92): ``unigradicon-finetune`` (no args) must fail with a non-zero exit. +- `def test_cli_rejects_unknown_argument()` (line 99) +- `def test_cli_rejects_missing_config_file(tmp_path, isolate_cwd)` (line 104) +- `def test_cli_starts_finetuning_with_valid_config(tmp_path, isolate_cwd, fake_image_reader, mock_finetune_multi)` (line 110): Happy path: a valid YAML drives main() through to ``finetune_multi`` +- `def test_cli_propagates_invalid_config_errors(tmp_path, isolate_cwd, fake_image_reader, mock_finetune_multi)` (line 126): Schema-validation failures must surface as ValueError before training +- `def test_cli_propagates_paired_without_subject_id_error(tmp_path, isolate_cwd, fake_image_reader, mock_finetune_multi)` (line 138): The fast-fail validator catches paired-without-subject_id before any +- `def test_cli_seed_is_applied_before_data_loader_build(tmp_path, isolate_cwd, fake_image_reader, mock_finetune_multi)` (line 156): When 'seed' is in the YAML, set_reproducibility_seed must run before + +## experiments/LongitudinalRegistration/uniGradICON/tests/finetuning/test_config.py + +- `def test_training_config_lambda_alias()` (line 8): The YAML key 'lambda' must map to the dataclass field 'lmbda'. +- `def test_training_config_unknown_keys_silently_dropped_in_from_dict()` (line 14): from_dict drops unknown keys; ConfigValidator handles the warning. +- `def test_training_config_defaults_when_empty_dict()` (line 21) +- `def test_training_config_network_input_shape_property()` (line 28) +- `def test_dataset_config_post_init_rejects_empty_required_fields()` (line 33) +- `def test_dataset_config_from_dict_coerces_lists_to_tuples()` (line 42) +- `def test_validate_rejects_missing_experiment()` (line 64) +- `def test_validate_rejects_missing_datasets()` (line 69) +- `def test_validate_rejects_empty_datasets_list()` (line 74) +- `def test_validate_rejects_invalid_similarity()` (line 82) +- `def test_validate_accepts_valid_similarities_case_insensitive(sim)` (line 88) +- `def test_validate_rejects_negative_learning_rate()` (line 92) +- `def test_validate_rejects_non_positive_int_keys(key)` (line 101) +- `def test_validate_rejects_negative_non_negative_keys(key)` (line 109) +- `def test_validate_rejects_bad_input_shape(shape)` (line 122) +- `def test_validate_rejects_bad_gpus(gpus)` (line 134) +- `def test_validate_accepts_valid_gpus()` (line 139) +- `def test_validate_rejects_bad_samples_per_epoch(spe)` (line 144) +- `def test_validate_accepts_null_or_positive_samples_per_epoch()` (line 149) +- `def test_paired_check_message_when_no_subject_id_field(tmp_path)` (line 154): When no entry has a subject_id at all, the error message names the +- `def test_load_entries_tolerates_null_optional_field(tmp_path)` (line 166): ``{"segmentation": null}`` should be treated as absent, not crash. +- `def test_validate_rejects_zero_dataset_weight()` (line 180) +- `def test_validate_rejects_unknown_dataset_type()` (line 187) +- `def test_validate_warns_on_unknown_training_key(caplog)` (line 194) +- `def test_paired_check_raises_when_no_subject_id_anywhere(tmp_path)` (line 217) +- `def test_paired_check_raises_when_each_subject_has_only_one_image(tmp_path)` (line 227) +- `def test_paired_check_passes_when_at_least_one_subject_has_two_images(tmp_path)` (line 238) +- `def test_paired_check_skips_unpaired_datasets(tmp_path)` (line 249): An 'unpaired' dataset with no subject_id should not trigger the check. +- `def test_required_fields_empty_for_image_only_training()` (line 262) +- `def test_required_fields_includes_segmentation_when_dice_loss_set()` (line 267) +- `def test_required_fields_includes_mask_when_loss_function_masking()` (line 272) +- `def test_required_fields_includes_mask_when_roi_masking()` (line 277) + +## experiments/LongitudinalRegistration/uniGradICON/tests/finetuning/test_pipeline.py + +- `def test_pipeline_rejects_zero_iterations_per_epoch(tmp_path, fake_image_reader)` (line 62): ``samples_per_epoch < batch_size * num_gpus`` would yield zero +- `def test_pipeline_starts_minimal_config(tmp_path, fake_image_reader)` (line 73): Smallest possible valid config: images only, single dataset, 1 epoch. +- `def test_pipeline_train_batch_shape(tmp_path, fake_image_reader)` (line 85): The train loader must yield batches with image_A / image_B at the +- `def test_pipeline_val_batch_shape(tmp_path, fake_image_reader)` (line 103): The validation loader uses batch_size=1 by default. +- `def test_pipeline_with_segmentation_includes_seg_in_batch(tmp_path, fake_image_reader)` (line 116): When dice_loss_weight > 0, batches include segmentation tensors. +- `def test_pipeline_iterates_full_epoch(tmp_path, fake_image_reader)` (line 130): The train loader iterates exactly samples_per_epoch / effective_batch +- `def test_pipeline_seed_yields_reproducible_first_batch(tmp_path, fake_image_reader)` (line 142): With seed set and num_workers=0, two independent runs must produce + +## experiments/LongitudinalRegistration/uniGradICON/tests/finetuning/test_samplers.py + +- `def test_random_pair_sampler_rejects_short_keys()` (line 8) +- `def test_random_pair_sampler_partner_is_never_anchor()` (line 13) +- `def test_random_pair_sampler_full_coverage()` (line 24): Every non-anchor key must be reachable as a partner for every anchor. +- `def test_random_pair_sampler_uniformity()` (line 39): Each non-anchor key should be picked with equal probability. +- `def test_subject_pair_sampler_rejects_when_no_subject_has_pair()` (line 60) +- `def test_subject_pair_sampler_partner_respects_subject()` (line 66) +- `def test_subject_pair_sampler_filters_keys_without_pairs()` (line 86): Keys missing from the entry list (or without subject_id) are dropped. + +## experiments/LongitudinalRegistration/uniGradICON/tests/test_command_arguments.py + +- **class TestCommandInterface** (line 11) + - `def __init__(self, methodName='runTest')` (line 12) + - `def test_register_unigradicon_inference(self)` (line 20) + - `def test_register_multigradicon_inference(self)` (line 64) + +## experiments/LongitudinalRegistration/uniGradICON/tests/test_itk_interface.py + +- **class TestItkInterface** (line 15) + - `def __init__(self, methodName='runTest')` (line 16) + - `def test_register_pair(self)` (line 22) + - `def test_preprocessing_mri(self)` (line 66) + - `def test_preprocessing_ct(self)` (line 85) + - `def test_itk_registration(self)` (line 104) + - `def test_register_pair_with_mask_masking(self)` (line 171): Test register_pair_with_mask with loss_function_masking (mask_A/B only). + - `def test_register_pair_with_mask_dice(self)` (line 190): Test register_pair_with_mask with dice_loss_weight (segmentation_A/B only). + - `def test_register_pair_with_mask_both(self)` (line 209): Test register_pair_with_mask with both mask and segmentation. + - `def test_register_pair_with_mask_images_only(self)` (line 230): ``register_pair_with_mask`` with no mask or segmentation kwargs + - `def test_itk_warp(self)` (line 248) + +## experiments/LongitudinalRegistration/uniGradICON/tests/test_requirements_sync.py + +- **class TestImports** (line 4) + - `def test_requirements_match_cfg(self)` (line 6) + +## experiments/LongitudinalRegistration/uniGradICON/training/dataset.py + +- **class COPDDataset** (line 13) + - `def __init__(self, phase='train', scale='2xdown', data_path=f'{DATASET_DIR}/half_res_preprocessed_transposed_SI', ROI_only=False, data_num=-1, desired_shape=None, device='cpu')` (line 14) + - `def process(self, img, desired_shape=None, device='cpu', seg=None)` (line 46) +- **class OAIDataset** (line 65) + - `def __init__(self, phase='train', scale='2xdownsample', data_path=f'{DATASET_DIR}/OAI', data_num=1000, desired_shape=None, device='cpu')` (line 66) + - `def process(self, img, desired_shape=None, device='cpu')` (line 93) +- **class HCPDataset** (line 114) + - `def __init__(self, phase='train', scale='2xdown', data_path=f'{DATASET_DIR}/HCP', data_num=1000, desired_shape=None, device='cpu')` (line 115) + - `def process(self, img, desired_shape=None, device='cpu')` (line 141) +- **class L2rAbdomenDataset** (line 161) + - `def __init__(self, data_path=f'{DATASET_DIR}/AbdomenCTCT', data_num=1000, desired_shape=None, device='cpu')` (line 162) + - `def process(self, img, desired_shape=None, device='cpu')` (line 180) +- **class L2rThoraxCBCTDataset** (line 198) + - `def __init__(self, data_path=f'{DATASET_DIR}/ThoraxCBCT', data_num=1000, desired_shape=None, device='cpu')` (line 199) + - `def process(self, img, desired_shape=None, device='cpu')` (line 224) +- **class ACDCDataset** (line 239) + - `def __init__(self, data_path=f'{DATASET_DIR}/ACDC', desired_shape=None)` (line 240) + - `def process(self, img, desired_shape=None)` (line 260) + +## experiments/LongitudinalRegistration/uniGradICON/training/dataset_multi.py + +- **class COPDDataset** (line 16) + - `def __init__(self, scale='2xdown', data_path=f'{DATASET_DIR}/half_res_preprocessed_transposed_SI', ROI_only=False, data_num=-1, desired_shape=None, device='cpu', return_labels=False)` (line 17) + - `def pack_and_process_image(self, img, seg=None)` (line 45) + - `def process(self, img, desired_shape=None, device='cpu', seg=None)` (line 50) +- **class BratsRegDataset** (line 72) + - `def __init__(self, data_path=f'{DATASET_DIR}/BraTS-Reg/BraTSReg_Training_Data_v3/', data_num=1000, desired_shape=None, device='cpu', return_labels=False, randomization='random')` (line 73) + - `def pack_and_process_image(self, image)` (line 118) + - `def process(self, img, desired_shape=None, device='cpu')` (line 123) +- **class L2rAbdomenDataset** (line 154) + - `def __init__(self, data_path=f'{DATASET_DIR}/AbdomenCTCT', data_num=1000, desired_shape=None, device='cpu', return_labels=False, randomization='random', augmentation=True)` (line 155) + - `def pack_and_process_image(self, case_path, invert=False)` (line 189) + - `def process(self, img, desired_shape=None, device='cpu')` (line 196) +- **class HCPDataset** (line 225) + - `def __init__(self, scale='2xdown', data_path=f'{DATASET_DIR}/ICON_brain_preprocessed_data', data_num=1000, desired_shape=None, device='cpu', return_labels=False, randomization='random')` (line 226) + - `def pack_and_process_image(self, image)` (line 259) + - `def process(self, img, desired_shape=None, device='cpu')` (line 264) +- **class ABCDFAMDDataset** (line 296) + - `def __init__(self, phase='train', data_path=f'{DATASET_DIR}/dti_scalars', data_num=1000, desired_shape=None, device='cpu', return_labels=False, randomization='random')` (line 297) + - `def pack_and_process_image(self, image)` (line 341) + - `def process(self, img, desired_shape=None, device='cpu')` (line 346) +- **class ABCDDataset** (line 378) + - `def __init__(self, phase='train', data_path=f'{DATASET_DIR}', data_num=1000, desired_shape=None, device='cpu', return_labels=False)` (line 379) + - `def pack_and_process_image(self, image)` (line 438) + - `def process(self, img, desired_shape=None, device='cpu')` (line 443) +- **class OAIMMDataset** (line 465) + - `def __init__(self, data_path=f'{DATASET_DIR}/oai', data_num=1000, desired_shape=None, device='cpu', return_labels=False)` (line 466) + - `def pack_and_process_image(self, image)` (line 491) + - `def process(self, img, desired_shape=None, device='cpu')` (line 496) +- **class L2rMRCTDataset** (line 518) + - `def __init__(self, data_path=f'{DATASET_DIR}/AbdomenMRCT/', data_num=1000, desired_shape=None, device='cpu', phase='train', augmentation=True, return_labels=False)` (line 519) + - `def pack(self, image)` (line 568) + - `def process_label(self, label, desired_shape=None, device='cpu')` (line 571) + - `def process_ct(self, img, desired_shape=None, device='cpu')` (line 577) + - `def process_mr(self, img, desired_shape=None, device='cpu')` (line 584) +- **class UKBiobankDataset** (line 612) + - `def __init__(self, data_path=f'{DATASET_DIR}/uk-biobank/', data_num=1000, desired_shape=None, device='cpu', phase='train', return_labels=False, randomization='random')` (line 613) + - `def pack_and_process_image(self, image)` (line 648) + - `def process(self, img, desired_shape=None, device='cpu')` (line 653) +- **class PancreasDataset** (line 691) + - `def __init__(self, phase='train', data_path=f'{DATASET_DIR}/pancreas/', data_num=1000, desired_shape=(175, 175, 175), device='cpu', return_labels=False)` (line 692) + - `def process(self, img, desired_shape=None, device='cpu')` (line 731) + - `def process_training_data(self, ct_img_arr, cb_img_arr)` (line 739) +- **class L2rThoraxCBCTDataset** (line 759) + - `def __init__(self, data_path=f'{DATASET_DIR}/ThoraxCBCT', data_num=1000, desired_shape=None, device='cpu', return_labels=False)` (line 760) + - `def pack_and_process_image(self, image)` (line 786) + - `def process(self, img, desired_shape=None, device='cpu')` (line 791) + +## experiments/LongitudinalRegistration/uniGradICON/training/train.py + +- `def write_stats(writer, stats, ite, prefix='')` (line 15) +- `def get_dataset()` (line 26) +- `def augment(image_A, image_B)` (line 37) +- `def train_kernel(optimizer, net, moving_image, fixed_image, writer, ite)` (line 85) +- `def train(net, optimizer, data_loader, val_data_loader, epochs=200, eval_period=-1, save_period=-1, step_callback=lambda net: None, unwrapped_net=None, data_augmenter=None)` (line 94): A training function intended for long running experiments, with tensorboard logging +- `def train_two_stage(input_shape, data_loader, val_data_loader, GPUS, epochs, eval_period, save_period, resume_from)` (line 195) + +## experiments/LongitudinalRegistration/uniGradICON/training/train_multi.py + +- `def write_stats(writer, stats, ite, prefix='')` (line 14) +- `def get_multi_training_set()` (line 25) +- `def get_multi_finetuning_set()` (line 49) +- `def augment(image_A, image_B, label_A, label_B)` (line 86) +- `def train_kernel(optimizer, net, moving_image, fixed_image, moving_label, fixed_label, writer, ite)` (line 136) +- `def train(net, optimizer, data_loader, val_data_loader, epochs=200, eval_period=-1, save_period=-1, step_callback=lambda net: None, unwrapped_net=None, data_augmenter=None)` (line 145): A training function intended for long running experiments, with tensorboard logging +- `def train_two_stage(input_shape, data_loader, val_data_loader, GPUS, epochs, eval_period, save_period, resume_from)` (line 245) +- `def finetune(net, data_loader, val_data_loader, GPUS, epochs, eval_period, save_period)` (line 317) + ## experiments/Lung-GatedCT_To_USD/data_dirlab_4d_ct.py - **class DataDirLab4DCT** (line 10): This class is used to store the data for the DirLab 4DCT dataset. diff --git a/pyproject.toml b/pyproject.toml index a386eb4..aa79a2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -227,6 +227,8 @@ module = [ "picsl_greedy", "pydicom", "pydicom.*", + "physicsnemo", + "physicsnemo.*", "pxr", "pxr.*", "pyvista", @@ -255,6 +257,7 @@ module = [ "physiomotion4d.cli.visualize_pca_modes", "physiomotion4d.vtk_to_usd.mesh_utils", "physiomotion4d.vtk_to_usd.vtk_reader", + "tutorial_09_physicsnemo_mesh_stage_model", ] disable_error_code = ["import-not-found", "import-untyped"] diff --git a/tutorials/__init__.py b/tutorials/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tutorials/tutorial_09_physicsnemo_mesh_stage_model.py b/tutorials/tutorial_09_physicsnemo_mesh_stage_model.py new file mode 100644 index 0000000..969d174 --- /dev/null +++ b/tutorials/tutorial_09_physicsnemo_mesh_stage_model.py @@ -0,0 +1,266 @@ +""" +Tutorial 9: Train a PhysicsNeMo model for DirLab mesh time-stage prediction. + +This tutorial uses the per-time-point PCA-fitted meshes created by Tutorial 8. +For each case, it trains a small PhysicsNeMo fully connected model that maps +reference mesh point coordinates and a requested normalized respiratory stage to +point displacements. The trained model can then predict a mesh at a new +user-specified stage without running image registration again. + +Data Required +------------- +Run Tutorial 8 first so ``output/tutorial_08_dirlab_pca_time_series`` contains +``Case*/meshes/*_pca_fit.vtp`` files. +""" + +# %% +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import Any, Optional + +import numpy as np +import pyvista as pv +import torch +from physicsnemo.models.mlp import FullyConnected + + +# nnUNetv2 (used by TotalSegmentator inside several workflows) spawns a +# multiprocessing.Pool. On Windows the spawn start method re-imports this +# script in each child; without the __name__ == "__main__" guard around +# top-level work, that re-import fires the segmenter again and Python's +# spawn-cascade detector raises RuntimeError. Wrapping consistently across +# tutorials also matches the style of tutorial_01. +if __name__ == "__main__": + # %% + REPO_ROOT = Path(__file__).resolve().parent.parent + TUTORIALS_DIR = Path(__file__).resolve().parent + TUTORIAL_08_OUTPUT_DIR = ( + TUTORIALS_DIR / "output" / "tutorial_08_dirlab_pca_time_series" + ) + OUTPUT_DIR = TUTORIALS_DIR / "output" / "tutorial_09_physicsnemo_mesh_stage_model" + TARGET_STAGE = 0.5 + CASE: Optional[int] = None + EPOCHS = 500 + POINTS_PER_MESH = 4096 + LEARNING_RATE = 1.0e-3 + LOG_LEVEL = logging.INFO + + DIRLAB_CASE_PREFIXES = [ + "Case1Pack", + "Case2Pack", + "Case3Pack", + "Case4Pack", + "Case5Pack", + "Case6Pack", + "Case7Pack", + "Case8Deploy", + "Case9Pack", + "Case10Pack", + ] + + def run_tutorial() -> dict[str, Any]: + """Train PhysicsNeMo stage models and predict meshes at ``target_stage``. + + Returns + ------- + dict[str, Any] + Per-case checkpoint, metadata, predicted mesh, and training loss paths. + """ + + tutorial_08_output_dir = TUTORIAL_08_OUTPUT_DIR + output_dir = OUTPUT_DIR + target_stage = TARGET_STAGE + case = CASE + epochs = EPOCHS + points_per_mesh = POINTS_PER_MESH + learning_rate = LEARNING_RATE + log_level = LOG_LEVEL + + logging.basicConfig(level=log_level) + if target_stage < 0.0 or target_stage > 1.0: + raise ValueError("target_stage must be in the normalized range [0.0, 1.0].") + + output_dir.mkdir(parents=True, exist_ok=True) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if case is None: + case_numbers = list(range(1, 11)) + else: + case_numbers = [case] + + tutorial_outputs: dict[str, Any] = {} + for case_number in case_numbers: + case_prefix = DIRLAB_CASE_PREFIXES[case_number - 1] + mesh_dir = tutorial_08_output_dir / case_prefix / "meshes" + mesh_files = sorted(mesh_dir.glob("*_pca_fit.vtp")) + if len(mesh_files) < 2: + print(f"Skipping {case_prefix}: fewer than two Tutorial 8 meshes found") + continue + + case_output_dir = output_dir / case_prefix + case_output_dir.mkdir(parents=True, exist_ok=True) + + reference_mesh = pv.read(str(mesh_files[0])) + reference_points = np.asarray(reference_mesh.points, dtype=np.float32) + if points_per_mesh <= 0 or points_per_mesh >= reference_mesh.n_points: + point_indices = np.arange(reference_mesh.n_points) + else: + point_indices = np.linspace( + 0, + reference_mesh.n_points - 1, + points_per_mesh, + dtype=np.int64, + ) + + coordinate_mean = reference_points.mean(axis=0) + coordinate_scale = reference_points.std(axis=0) + coordinate_scale = np.where(coordinate_scale == 0.0, 1.0, coordinate_scale) + normalized_reference_points = ( + reference_points[point_indices] - coordinate_mean + ) / coordinate_scale + + training_inputs: list[np.ndarray] = [] + training_targets: list[np.ndarray] = [] + stage_denominator = max(1, len(mesh_files) - 1) + for stage_index, mesh_file in enumerate(mesh_files): + mesh = pv.read(str(mesh_file)) + if mesh.n_points != reference_mesh.n_points: + raise ValueError( + f"{mesh_file} has {mesh.n_points} points, expected " + f"{reference_mesh.n_points}. Tutorial 8 meshes must share topology." + ) + + stage = stage_index / stage_denominator + stage_column = np.full((len(point_indices), 1), stage, dtype=np.float32) + training_inputs.append( + np.hstack([normalized_reference_points, stage_column]) + ) + training_targets.append( + np.asarray(mesh.points[point_indices], dtype=np.float32) + - reference_points[point_indices] + ) + + inputs_array = np.vstack(training_inputs).astype(np.float32) + targets_array = np.vstack(training_targets).astype(np.float32) + displacement_scale = float(np.max(np.abs(targets_array))) + if displacement_scale == 0.0: + displacement_scale = 1.0 + targets_array = targets_array / displacement_scale + + inputs_tensor = torch.from_numpy(inputs_array).to(device) + targets_tensor = torch.from_numpy(targets_array).to(device) + + model = FullyConnected( + in_features=4, + layer_size=128, + out_features=3, + num_layers=4, + activation_fn="silu", + skip_connections=True, + ).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) + loss_function = torch.nn.MSELoss() + + losses: list[float] = [] + model.train() + for epoch in range(epochs): + optimizer.zero_grad() + prediction = model(inputs_tensor) + loss = loss_function(prediction, targets_tensor) + loss.backward() + optimizer.step() + losses.append(float(loss.detach().cpu())) + if epoch == 0 or (epoch + 1) % 100 == 0 or epoch + 1 == epochs: + print( + f"{case_prefix} epoch {epoch + 1:04d}/{epochs}: " + f"loss={losses[-1]:.6f}" + ) + + model.eval() + all_normalized_points = ( + reference_points - coordinate_mean + ) / coordinate_scale + all_stage_column = np.full( + (reference_mesh.n_points, 1), + target_stage, + dtype=np.float32, + ) + prediction_inputs = np.hstack([all_normalized_points, all_stage_column]) + predicted_displacements: list[np.ndarray] = [] + with torch.no_grad(): + for start in range(0, reference_mesh.n_points, 65536): + stop = min(start + 65536, reference_mesh.n_points) + prediction_tensor = torch.from_numpy( + prediction_inputs[start:stop].astype(np.float32) + ).to(device) + predicted_displacements.append( + model(prediction_tensor).cpu().numpy() * displacement_scale + ) + + predicted_mesh = reference_mesh.copy(deep=True) + predicted_mesh.points = reference_points + np.vstack( + predicted_displacements + ) + + stage_tag = f"{target_stage:.3f}".replace(".", "p") + checkpoint_file = case_output_dir / "physicsnemo_stage_model.pt" + metadata_file = case_output_dir / "physicsnemo_stage_model_metadata.json" + losses_file = case_output_dir / "training_losses.json" + predicted_mesh_file = ( + case_output_dir / f"{case_prefix}_stage_{stage_tag}.vtp" + ) + + torch.save( + { + "model_state_dict": model.state_dict(), + "coordinate_mean": coordinate_mean.tolist(), + "coordinate_scale": coordinate_scale.tolist(), + "displacement_scale": displacement_scale, + "target_stage": target_stage, + "mesh_files": [str(mesh_file) for mesh_file in mesh_files], + }, + checkpoint_file, + ) + metadata_file.write_text( + json.dumps( + { + "architecture": "physicsnemo.models.mlp.FullyConnected", + "input_features": [ + "reference_x_normalized", + "reference_y_normalized", + "reference_z_normalized", + "normalized_stage", + ], + "output_features": ["dx", "dy", "dz"], + "target_stage": target_stage, + "epochs": epochs, + "points_per_mesh": len(point_indices), + "learning_rate": learning_rate, + "coordinate_mean": coordinate_mean.tolist(), + "coordinate_scale": coordinate_scale.tolist(), + "displacement_scale": displacement_scale, + "training_meshes": [str(mesh_file) for mesh_file in mesh_files], + }, + indent=2, + ), + encoding="utf-8", + ) + losses_file.write_text(json.dumps(losses, indent=2), encoding="utf-8") + predicted_mesh.save(predicted_mesh_file) + + tutorial_outputs[case_prefix] = { + "checkpoint_file": checkpoint_file, + "metadata_file": metadata_file, + "losses_file": losses_file, + "predicted_mesh_file": predicted_mesh_file, + "final_loss": losses[-1], + } + + return tutorial_outputs + + # %% + # Run this cell in VS Code or Cursor: + tutorial_results = run_tutorial() From d1c8a1a3b9423b279a23c206a68791040772ec12 Mon Sep 17 00:00:00 2001 From: Stephen Aylward Date: Wed, 20 May 2026 07:43:03 -0400 Subject: [PATCH 2/2] ENH: Code stability suggestions from AI review --- .pre-commit-config.yaml | 2 +- ...utorial_09_physicsnemo_mesh_stage_model.py | 24 +++++++++++++++---- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 539e1af..023d75e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ repos: rev: v0.14.14 hooks: # Run the linter with auto-fixes - - id: ruff + - id: ruff-check args: [--fix] # Run the formatter - id: ruff-format diff --git a/tutorials/tutorial_09_physicsnemo_mesh_stage_model.py b/tutorials/tutorial_09_physicsnemo_mesh_stage_model.py index 969d174..3a83dc5 100644 --- a/tutorials/tutorial_09_physicsnemo_mesh_stage_model.py +++ b/tutorials/tutorial_09_physicsnemo_mesh_stage_model.py @@ -35,7 +35,6 @@ # tutorials also matches the style of tutorial_01. if __name__ == "__main__": # %% - REPO_ROOT = Path(__file__).resolve().parent.parent TUTORIALS_DIR = Path(__file__).resolve().parent TUTORIAL_08_OUTPUT_DIR = ( TUTORIALS_DIR / "output" / "tutorial_08_dirlab_pca_time_series" @@ -86,18 +85,33 @@ def run_tutorial() -> dict[str, Any]: output_dir.mkdir(parents=True, exist_ok=True) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + num_cases = len(DIRLAB_CASE_PREFIXES) if case is None: - case_numbers = list(range(1, 11)) + case_numbers = list(range(1, num_cases + 1)) else: + if not 1 <= case <= num_cases: + raise ValueError( + f"CASE={case} is out of range; must be an integer between 1 " + f"and {num_cases} (inclusive)." + ) case_numbers = [case] tutorial_outputs: dict[str, Any] = {} for case_number in case_numbers: case_prefix = DIRLAB_CASE_PREFIXES[case_number - 1] mesh_dir = tutorial_08_output_dir / case_prefix / "meshes" - mesh_files = sorted(mesh_dir.glob("*_pca_fit.vtp")) + mesh_files = ( + sorted(mesh_dir.glob("*_pca_fit.vtp")) if mesh_dir.exists() else [] + ) if len(mesh_files) < 2: - print(f"Skipping {case_prefix}: fewer than two Tutorial 8 meshes found") + message = ( + f"Tutorial 8 output for {case_prefix} is missing or incomplete: " + f"found {len(mesh_files)} '*_pca_fit.vtp' file(s) in {mesh_dir}, " + "expected at least 2. Run Tutorial 8 before Tutorial 9." + ) + if case is not None: + raise FileNotFoundError(message) + logging.info(f"Skipping {case_prefix}: {message}") continue case_output_dir = output_dir / case_prefix @@ -174,7 +188,7 @@ def run_tutorial() -> dict[str, Any]: optimizer.step() losses.append(float(loss.detach().cpu())) if epoch == 0 or (epoch + 1) % 100 == 0 or epoch + 1 == epochs: - print( + logging.info( f"{case_prefix} epoch {epoch + 1:04d}/{epochs}: " f"loss={losses[-1]:.6f}" )