
Text-supervised training for medical image segmentation using paired images and clinical reports.

AImageLab-zip/Report-Guided-Segmentation

Medical Image Segmentation Framework

A flexible PyTorch-based framework for training 2D and 3D medical image segmentation models, with support for patch-based training, configurable architectures, and comprehensive metrics tracking.

Project Structure

Brain-Segmentation/
├── base/                    # Abstract base classes
│   ├── base_dataset2d_sliced.py
│   ├── base_dataset.py
│   ├── base_model.py
│   └── base_trainer.py
├── config/                  # Configuration files for training and transforms
│   ├── config_atlas.json
│   ├── atlas_transforms.json
│   └── ...
├── datasets/                # Dataset loading and preprocessing (inherited from base_dataset / base_dataset2d_sliced)
│   ├── DatasetFactory.py
│   ├── ATLAS.py
│   └── BraTS2D.py
├── losses/                  # Loss function implementations
│   └── LossFactory.py
├── metrics/                 # Metrics computation and tracking
│   ├── MetricsFactory.py
│   └── MetricsManager.py
├── models/                  # Model architectures (inherited from base_model)
│   ├── ModelFactory.py
│   ├── UNet2D.py
│   └── UNet3D.py
├── optimizers/              # Optimizer configurations
│   └── OptimizerFactory.py
├── trainer/                 # Training logic (inherited from base_trainer)
│   ├── trainer_2Dsliced.py
│   └── trainer_3D.py
├── transforms/              # Data augmentation and preprocessing
│   └── TransformsFactory.py
├── utils/                   # Utility functions
│   ├── util.py
│   └── pad_unpad.py
├── config.py               # Config file handler
├── main.py                 # Training entry point
└── requirements.txt        # Python dependencies

Installation

  1. Clone the repository:
     git clone https://github.com/kev98/Medical-Image-Segmentation.git
     cd Medical-Image-Segmentation
  2. Create and activate a virtual environment:
     python -m venv .venv
     source .venv/bin/activate
  3. Install dependencies:
     pip install -r requirements.txt

Quick Start and Usage Examples

The following are basic examples. You can add further CLI parameters to main.py as needed; main.py must remain the entry point for training.

Command Line Arguments

Command line arguments implemented in the provided main.py file:

  • --config: Path to configuration JSON file (required)
  • --epochs: Number of training epochs (required)
  • --save_path: Directory to save model checkpoints (required)
  • --validation: Enable validation during training (flag)
  • --resume: Resume training from last checkpoint (flag)
  • --debug: Enable debug mode with verbose output (flag)
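The documented arguments map directly onto Python's standard `argparse` module. This is a sketch of a compatible parser, not a copy of the repository's `main.py`; the `parse_args` wrapper and the help strings are assumptions.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options documented for main.py (sketch)."""
    parser = argparse.ArgumentParser(description="Train a medical image segmentation model.")
    parser.add_argument("--config", required=True, help="Path to the configuration JSON file")
    parser.add_argument("--epochs", type=int, required=True, help="Number of training epochs")
    parser.add_argument("--save_path", required=True, help="Directory for model checkpoints")
    # Boolean flags default to False and become True when present on the command line.
    parser.add_argument("--validation", action="store_true", help="Run validation during training")
    parser.add_argument("--resume", action="store_true", help="Resume from the last checkpoint")
    parser.add_argument("--debug", action="store_true", help="Enable verbose debug output")
    return parser.parse_args(argv)


# Explicit argv list for illustration; pass None to read sys.argv instead.
args = parse_args(["--config", "config/config_atlas.json", "--epochs", "100",
                   "--save_path", "checkpoints", "--validation"])
print(args.epochs, args.validation, args.resume)  # 100 True False
```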

Example: launch main.py to train a 3D segmentation model with validation enabled, resuming from the last checkpoint:

python main.py \
  --config config/config_atlas.json \
  --epochs 100 \
  --save_path /folder_containing_model_last.pth \
  --validation \
  --resume
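Together, `--resume` and `--save_path` imply a convention in which `model_last.pth` is refreshed every epoch and `model_best.pth` only when the validation score improves. A minimal sketch of that bookkeeping follows, assuming a higher-is-better score; the `CheckpointTracker` class and the `resume_path` helper are hypothetical, the scores in the loop are illustrative inputs, and actual saving would be done with `torch.save(state, path)` where the comment indicates.

```python
import os
from typing import List, Optional


def resume_path(save_path: str) -> Optional[str]:
    """Return the path of model_last.pth inside save_path if it exists, else None."""
    candidate = os.path.join(save_path, "model_last.pth")
    return candidate if os.path.isfile(candidate) else None


class CheckpointTracker:
    """Decide which checkpoint files to write after each epoch (higher score is better)."""

    def __init__(self, save_path: str, best_score: float = float("-inf")) -> None:
        self.save_path = save_path
        self.best_score = best_score

    def paths_to_save(self, score: Optional[float]) -> List[str]:
        # model_last.pth is always refreshed so training can be resumed later.
        paths = [os.path.join(self.save_path, "model_last.pth")]
        # model_best.pth is refreshed only when a validation score beats the best so far.
        if score is not None and score > self.best_score:
            self.best_score = score
            paths.append(os.path.join(self.save_path, "model_best.pth"))
        return paths


tracker = CheckpointTracker("checkpoints")
for epoch, dice in enumerate([0.61, 0.74, 0.70]):  # illustrative validation scores
    for path in tracker.paths_to_save(dice):
        # In a real trainer: torch.save(state_dict, path)
        print(epoch, os.path.basename(path))
```

Passing `None` as the score (for example, when validation is disabled) still refreshes `model_last.pth` but never touches `model_best.pth`.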

Implement Your Own Training

To set up a complete training pipeline, follow these steps:


Detailed Components

For detailed documentation on each component, refer to the README files in their respective directories:

  • Base Classes - Abstract base classes for datasets, models, and trainers
  • Configuration - JSON configuration files for training and transforms
  • Datasets - Dataset loading and preprocessing
  • Losses - Loss function implementations
  • Metrics - Metrics computation and tracking
  • Models - Model architectures
  • Optimizers - Optimizer configurations
  • Trainers - Training logic
  • Transforms - Data augmentation and preprocessing
  • Utils - Utility functions

Notes

  • For patch-based training with 3D volumes, the framework uses TorchIO's Queue and GridSampler.
  • Metrics are automatically computed per class and then averaged.
  • Checkpoints are saved as model_last.pth and model_best.pth in the folder given by --save_path.
  • The framework is compatible with PyTorch 2.3+ and uses TorchIO's SubjectsLoader, a DataLoader subclass that correctly collates TorchIO Subject objects into batches.
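To illustrate per-class computation followed by averaging, the sketch below computes a Dice score for each foreground class of a flattened label map and then takes their mean. Dice is only an assumed example metric: the metrics actually available are defined in `MetricsFactory.py`. The function name `per_class_dice`, the exclusion of background class 0, and the convention of scoring a class absent from both prediction and target as 1.0 are also assumptions.

```python
from typing import Dict, Sequence, Tuple


def per_class_dice(
    pred: Sequence[int], target: Sequence[int], num_classes: int
) -> Tuple[Dict[int, float], float]:
    """Return Dice for each foreground class (1..num_classes-1) and their mean."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same number of voxels")
    scores: Dict[int, float] = {}
    for c in range(1, num_classes):  # skip background class 0
        p = [v == c for v in pred]
        t = [v == c for v in target]
        intersection = sum(a and b for a, b in zip(p, t))
        denom = sum(p) + sum(t)
        # Convention: a class missing from both prediction and target counts as perfect.
        scores[c] = 1.0 if denom == 0 else 2.0 * intersection / denom
    mean = sum(scores.values()) / len(scores) if scores else 0.0
    return scores, mean


pred = [0, 1, 1, 2, 2, 0]
target = [0, 1, 2, 2, 2, 0]
scores, mean = per_class_dice(pred, target, num_classes=3)
print({c: round(s, 3) for c, s in scores.items()}, round(mean, 3))  # {1: 0.667, 2: 0.8} 0.733
```

This is a macro average: each class contributes equally to the mean regardless of how many voxels it occupies, so small lesion classes are not drowned out by large structures.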
