A research-grade, correctness-first distributed training framework in Python, inspired by the design space of PyTorch Distributed and Horovod.
This repository focuses on clear, testable implementations of core distributed training protocols before performance optimizations:
- deterministic single-node baseline
- synchronous Parameter Server training
- synchronous Ring All-Reduce training
- checkpointing and failure recovery
- benchmark harness with plots and interpretation
- correctness over speed
- explicit failure handling and assumptions
- modular design that can evolve into production-grade components
- MNIST MLP baseline training pipeline
- seeded deterministic behavior
- modular separation of config, model, and trainer
Key files:
- src/dmlf/baseline/config.py
- src/dmlf/baseline/model.py
- src/dmlf/baseline/trainer.py
- scripts/train_mnist_baseline.py
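The seeded deterministic behavior listed above can be sketched as follows; `seed_everything` is a hypothetical helper name, and only the stdlib RNG is shown to stay dependency-free (the real trainer would also seed numpy and torch when installed):

```python
import random

def seed_everything(seed: int) -> None:
    """Seed the RNGs the training pipeline touches (hypothetical sketch).

    A torch-backed trainer would additionally call torch.manual_seed(seed)
    and numpy.random.seed(seed) behind optional-dependency guards.
    """
    random.seed(seed)

# Two runs with the same seed must produce identical draws.
seed_everything(42)
first = [random.random() for _ in range(3)]
seed_everything(42)
second = [random.random() for _ in range(3)]
assert first == second
```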
- explicit runtime guards for optional dependencies (`torch`, `torchvision`)
- actionable error messages when dependencies are missing
Key file:
src/dmlf/env.py
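The guard pattern might look like the sketch below; the function name `require` and message wording are assumptions, not the actual contents of `src/dmlf/env.py`:

```python
import importlib

def require(module_name: str, hint: str):
    """Import an optional dependency or fail with an actionable message.

    Hypothetical sketch of a runtime guard: instead of a bare ImportError
    deep inside a code path, the caller gets told exactly what to install.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise RuntimeError(
            f"Optional dependency '{module_name}' is not installed. {hint}"
        ) from exc

# Example: guard an optional code path on a dependency that is missing.
try:
    require("some_missing_dep", "Install it with: pip install some_missing_dep")
except RuntimeError as err:
    print(err)  # actionable message instead of a bare ImportError
```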
- torch-free, unit-testable tensor abstraction (`MockTensor`)
- parameter-server aggregation and ring all-reduce logic tests
Key files:
- src/dmlf/distributed/mock_tensor.py
- src/dmlf/distributed/algorithms.py
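The torch-free aggregation logic tested here can be illustrated on plain Python lists; this is a sketch of the synchronous parameter-server reduction, not the repository's exact implementation:

```python
def aggregate_gradients(worker_grads):
    """Average per-worker gradient vectors element-wise (torch-free sketch).

    worker_grads: list of equal-length lists of floats, one per worker.
    Mirrors a synchronous parameter-server step: the server waits for
    every worker's gradient, validates shapes, then applies the mean.
    """
    if not worker_grads:
        raise ValueError("no worker gradients to aggregate")
    dim = len(worker_grads[0])
    if any(len(g) != dim for g in worker_grads):
        raise ValueError("gradient dimension mismatch across workers")
    n = len(worker_grads)
    return [sum(g[i] for g in worker_grads) / n for i in range(dim)]

# Three workers, two parameters each: element-wise mean.
avg = aggregate_gradients([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
assert avg == [3.0, 4.0]
```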
- per-parameter gradient capture immediately after backward
- structured snapshots for later synchronization layers
Key file:
src/dmlf/distributed/gradient_interceptor.py
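The capture pattern can be sketched torch-free; in a real torch path the hook would be registered with `Parameter.register_hook`, while here it is invoked manually so the snapshot logic stays unit-testable. Class and field names are illustrative assumptions:

```python
from dataclasses import dataclass

@dataclass
class GradientSnapshot:
    """Structured record of one parameter's gradient at one step (sketch)."""
    step: int
    param_name: str
    grad: list

class GradientInterceptor:
    """Capture per-parameter gradients right after backward (torch-free sketch)."""

    def __init__(self):
        self.snapshots = []

    def hook_for(self, step: int, param_name: str):
        """Build a hook closure for one parameter at one training step."""
        def _hook(grad):
            self.snapshots.append(GradientSnapshot(step, param_name, list(grad)))
            return grad  # capture only; never alter the gradient
        return _hook

interceptor = GradientInterceptor()
hook = interceptor.hook_for(step=0, param_name="fc1.weight")
hook([0.1, -0.2])  # simulates the gradient arriving after backward
assert interceptor.snapshots[0].param_name == "fc1.weight"
```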
- single-machine process simulation with data sharding
- independent worker replicas and local gradient computation
Key file:
src/dmlf/baseline/trainer.py
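Contiguous data sharding across worker replicas can be sketched as below; the function name and remainder policy are assumptions, shown to illustrate gap-free, overlap-free shards:

```python
def shard_indices(num_samples: int, world_size: int, rank: int) -> range:
    """Return this rank's contiguous slice of the dataset (sketch).

    Any remainder is given to the lowest ranks, so shard sizes differ
    by at most one sample and the shards tile the dataset exactly.
    """
    base, extra = divmod(num_samples, world_size)
    start = rank * base + min(rank, extra)
    stop = start + base + (1 if rank < extra else 0)
    return range(start, stop)

# 10 samples over 4 workers -> sizes 3, 3, 2, 2 with no gaps or overlap.
shards = [shard_indices(10, 4, r) for r in range(4)]
assert [len(s) for s in shards] == [3, 3, 2, 2]
assert [i for s in shards for i in s] == list(range(10))
```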
- one server process + multiple worker processes
- strict step/rank validation
- periodic checkpointing and restart-based recovery
- checksum validation to prevent silent checkpoint corruption
Key file:
src/dmlf/distributed/parameter_server_sync.py
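The checksum-validated checkpoint idea can be sketched with `hashlib` and JSON; the wrapper format and function names are assumptions, not the repository's on-disk format:

```python
import hashlib
import json

def save_checkpoint(state: dict) -> bytes:
    """Serialize state with an embedded SHA-256 checksum (sketch)."""
    payload = json.dumps(state, sort_keys=True).encode()
    digest = hashlib.sha256(payload).hexdigest()
    return json.dumps({"checksum": digest, "payload": payload.decode()}).encode()

def load_checkpoint(blob: bytes) -> dict:
    """Reject checkpoints whose payload does not match the stored checksum."""
    wrapper = json.loads(blob)
    payload = wrapper["payload"].encode()
    if hashlib.sha256(payload).hexdigest() != wrapper["checksum"]:
        raise ValueError("checkpoint corrupted: checksum mismatch")
    return json.loads(payload)

# Round trip succeeds; any silent bit-flip in the payload is detected on load.
ckpt = save_checkpoint({"step": 5, "params": [0.1, 0.2]})
assert load_checkpoint(ckpt)["step"] == 5
```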
- multi-process ring topology
- reduce-scatter + all-gather with chunk metadata checks
- strict message validation (step/phase/round/source)
Key file:
src/dmlf/distributed/ring_allreduce_sync.py
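The reduce-scatter + all-gather schedule can be simulated on plain lists; this is a single-process sketch of the standard ring algorithm (chunk indices as in the textbook formulation), not the multi-process implementation above:

```python
def ring_allreduce(worker_vectors):
    """Simulate a synchronous ring all-reduce on plain Python lists (sketch).

    Each of the n workers owns one chunk of the vector. Reduce-scatter
    runs n-1 rounds: in round t, rank r sends chunk (r - t) mod n to
    rank (r + 1) mod n, which adds it into its buffer. All-gather then
    circulates the fully reduced chunks for another n-1 rounds.
    """
    n = len(worker_vectors)
    dim = len(worker_vectors[0])
    assert dim % n == 0, "vector length must divide evenly into n chunks"
    size = dim // n
    bufs = [list(v) for v in worker_vectors]

    def sl(c):  # index slice covering chunk c
        return slice(c * size, (c + 1) * size)

    # Reduce-scatter: afterwards rank r holds the full sum of chunk (r+1) mod n.
    for t in range(n - 1):
        sends = [(r, (r - t) % n, bufs[r][sl((r - t) % n)]) for r in range(n)]
        for r, c, data in sends:  # snapshot first: all sends happen "simultaneously"
            dst, start = (r + 1) % n, sl(c).start
            for i, v in enumerate(data):
                bufs[dst][start + i] += v

    # All-gather: in round t, rank r forwards fully reduced chunk (r + 1 - t) mod n.
    for t in range(n - 1):
        sends = [(r, (r + 1 - t) % n, bufs[r][sl((r + 1 - t) % n)]) for r in range(n)]
        for r, c, data in sends:
            bufs[(r + 1) % n][sl(c)] = data

    return bufs

# Two workers, four parameters: every worker ends with the element-wise sum.
out = ring_allreduce([[1, 2, 3, 4], [10, 20, 30, 40]])
assert out[0] == out[1] == [11, 22, 33, 44]
```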
- throughput
- scaling efficiency
- convergence vs serial baseline
- estimated communication overhead
Key files:
- scripts/run_distributed_benchmarks.py
- experiments/benchmarks/results/*
- docs/benchmarks/benchmark_report.md
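Throughput and scaling efficiency can be computed as below; this is a generic sketch of the standard definitions, not necessarily the formulas used by the benchmark harness:

```python
def throughput(samples: int, seconds: float) -> float:
    """Samples processed per second."""
    return samples / seconds

def scaling_efficiency(serial_tp: float, parallel_tp: float, world_size: int) -> float:
    """Fraction of ideal linear speedup achieved: (parallel/serial) / world_size.

    1.0 means perfect linear scaling; the shortfall is a rough proxy for
    communication and synchronization overhead.
    """
    return (parallel_tp / serial_tp) / world_size

# Serial: 1000 samples/s. Four workers together: 3200 samples/s -> 80% efficiency.
eff = scaling_efficiency(1000.0, 3200.0, world_size=4)
assert abs(eff - 0.8) < 1e-9
```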
```
Distributed ML Training Framework/
  configs/
  docs/
  experiments/
  scripts/
  src/dmlf/
    baseline/
    distributed/
  tests/
```
- Python 3.11+
- `matplotlib` for benchmark plotting
- Optional: `torch`, `torchvision` for baseline training and hook-based tests
```
python -m unittest discover -s tests -p "test_*.py" -v
```

Notes:
- torch-dependent tests are skipped automatically when torch is not installed.
```
python scripts/train_mnist_baseline.py --epochs 3
python scripts/run_sync_parameter_server.py --world-size 4 --num-steps 5 --parameter-dim 40 --shard-size 16
python scripts/run_sync_ring_allreduce.py --world-size 4 --num-steps 6 --parameter-dim 8 --shard-size 5
python scripts/run_distributed_benchmarks.py
```

Outputs are written to:
- experiments/benchmarks/results/
- docs/benchmarks/benchmark_report.md
- deterministic synthetic-data protocol checks for distributed algorithms
- strict message validation in distributed runtime paths
- checkpoint checksum validation to detect corruption
- explicit hard-failure behavior for malformed protocol states
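The strict message validation and hard-failure policy can be sketched as follows; the `Message` fields mirror the step/phase/round/source checks mentioned above, but the class and function names are illustrative assumptions:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Message:
    """One protocol message's metadata (sketch)."""
    step: int
    phase: str    # e.g. "reduce_scatter" or "all_gather"
    round: int
    source: int

def validate_message(msg: Message, expected: Message) -> None:
    """Hard-fail on any protocol mismatch rather than proceeding silently."""
    for field in ("step", "phase", "round", "source"):
        got, want = getattr(msg, field), getattr(expected, field)
        if got != want:
            raise RuntimeError(
                f"protocol violation: {field}={got!r}, expected {want!r}"
            )

expected = Message(step=3, phase="reduce_scatter", round=1, source=2)
validate_message(Message(3, "reduce_scatter", 1, 2), expected)  # ok: no exception
try:
    validate_message(Message(3, "all_gather", 1, 2), expected)
except RuntimeError as err:
    print(err)  # names the offending field instead of corrupting state
```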
- CPU-only simulated communication (multiprocessing queues)
- no real multi-node networking stack in current implementation
- no GPU/NCCL/Gloo transport integration yet
- convergence claims are bounded to synthetic objective tests and baseline checks
- integrate real PyTorch model gradient synchronization path over current protocols
- add robust coordinator for elastic membership and retries
- introduce transport abstraction (local queue, TCP, and process-group backends)
- extend checkpoint state to optimizer/runtime metadata in PyTorch path
Add your preferred license in LICENSE (currently not set).