Predict the bacterial host genus of a plasmid directly from its nucleotide sequence, by fine-tuning a pretrained DNA foundation model (default: DNABERT-2) on PLSDB 2025.
The headline public API is:

```python
from plasmid_host_range.predict import predict_host_genus

preds = predict_host_genus("path/to/plasmid.fasta", top_k=3)
for p in preds:
    print(p.accession, p.top_genera, p.scores)
```

Install in editable mode with the development extras:

```bash
cd ml/plasmid-host-range
python -m venv .venv && source .venv/bin/activate
pip install -e ".[dev]"
```

Or with uv:

```bash
uv venv && source .venv/bin/activate
uv pip install -e ".[dev]"
```

PLSDB is distributed on Figshare. The current release is PLSDB 2024_05_31_v2
(Figshare article 27252609).
The `download` command queries the public Figshare API and pulls only the files we need
(≈ 2 GB total, mostly `sequences.fasta.bz2`):

```bash
plasmid-host-range download
```

This fetches the following into `data/raw/`:
| File | Purpose |
|---|---|
| `sequences.fasta` (decompressed from `.bz2`) | Plasmid nucleotide sequences |
| `nuccore.csv` | Per-plasmid metadata (accession, taxon ID, length, …) |
| `taxonomy.csv` | Taxon ID → lineage (genus, species, …) |
| `biosample.csv` | Biosample-level host info (fallback source) |
| `README.md` | Upstream PLSDB column docs |
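Because the archive is large, `sequences.fasta.bz2` is best decompressed as a stream rather than read into memory in one go. A minimal standard-library sketch; `decompress_bz2` is a hypothetical helper name, not necessarily what the `data/` module does:

```python
import bz2
import shutil
from pathlib import Path


def decompress_bz2(src: Path, dest: Path, chunk_size: int = 1 << 20) -> Path:
    """Stream-decompress a .bz2 file to `dest` without loading it into memory."""
    dest.parent.mkdir(parents=True, exist_ok=True)
    # bz2.open yields decompressed bytes; copyfileobj moves them in fixed-size chunks
    with bz2.open(src, "rb") as fin, open(dest, "wb") as fout:
        shutil.copyfileobj(fin, fout, chunk_size)
    return dest
```

For example, `decompress_bz2(Path("data/raw/sequences.fasta.bz2"), Path("data/raw/sequences.fasta"))`.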
The download is resumable at the file level: re-running `download` skips files that
are already present and non-empty.
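File-level resumption amounts to an existence-and-size check before each fetch. The sketch below assumes Figshare's public `GET /v2/articles/{id}/files` endpoint, which returns a JSON list of file objects carrying `name` and `download_url` fields; `WANTED`, `list_article_files`, `needs_download` and `plan_downloads` are hypothetical names, and pagination of the file listing is ignored:

```python
import json
import urllib.request
from pathlib import Path

FIGSHARE_FILES = "https://api.figshare.com/v2/articles/{article_id}/files"
# Hypothetical selection of upstream file names; the real list lives in the download module.
WANTED = {"sequences.fasta.bz2", "nuccore.csv", "taxonomy.csv", "biosample.csv", "README.md"}


def list_article_files(article_id: int) -> list[dict]:
    """Fetch file metadata for a Figshare article (assumes a single, unpaginated response)."""
    url = FIGSHARE_FILES.format(article_id=article_id)
    with urllib.request.urlopen(url, timeout=60) as resp:
        return json.load(resp)


def needs_download(path: Path) -> bool:
    """File-level resume rule: skip anything that already exists and is non-empty."""
    return not (path.is_file() and path.stat().st_size > 0)


def plan_downloads(files: list[dict], raw_dir: Path, wanted: set[str] = WANTED) -> list[tuple[str, Path]]:
    """Return (download_url, destination) pairs for wanted files that are still missing."""
    plan = []
    for f in files:
        if f["name"] not in wanted:
            continue  # only pull the files the pipeline actually uses
        dest = raw_dir / f["name"]
        if needs_download(dest):
            plan.append((f["download_url"], dest))
    return plan
```

Usage: `plan_downloads(list_article_files(27252609), Path("data/raw"))`. A truncated file left by an interrupted transfer still counts as present under this rule; comparing against the file size Figshare reports would catch that case.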
Then build the processed splits:

```bash
plasmid-host-range preprocess --top-n-genera 20
```

This joins `nuccore.csv` with `taxonomy.csv` on taxon ID, filters plasmids by length,
labels the top-N most common host genera (all remaining genera become a single `Other` class), and writes
`data/processed/{train,val,test}.parquet` with columns
`{accession, sequence, label, genus, species}`. The train/val/test split is grouped by host species,
so that near-identical plasmids cannot leak across splits.
Smoke test (tiny subset, CPU-friendly):

```bash
plasmid-host-range train --config configs/smoke.yaml
```

Full fine-tune (GPU recommended):

```bash
plasmid-host-range train --config configs/default.yaml
```

Evaluate the best checkpoint on the test split:

```bash
plasmid-host-range evaluate --checkpoint checkpoints/best
```

This produces `reports/test_metrics.json` and `reports/confusion_matrix.png`. A k-mer + logistic
regression baseline is reported alongside, so any improvement from the deep model is quantified rather than
assumed.
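The baseline represents each plasmid by its k-mer content. A minimal sketch of such a featuriser, assuming k = 4, frequencies normalised to sum to one, and windows containing ambiguous bases (e.g. `N`) dropped; none of these choices is confirmed by `baselines.py`, and `kmer_vocab` / `kmer_profile` are hypothetical names. Stacking one profile per plasmid gives a matrix that a multinomial logistic regression (for example scikit-learn's `LogisticRegression`) can be fitted on:

```python
from collections import Counter
from itertools import product


def kmer_vocab(k: int) -> list[str]:
    """All 4**k ACGT k-mers in lexicographic order, giving a fixed feature order."""
    return ["".join(p) for p in product("ACGT", repeat=k)]


def kmer_profile(seq: str, k: int = 4) -> list[float]:
    """Normalised k-mer frequency vector; windows with non-ACGT bases are ignored."""
    seq = seq.upper()
    counts = Counter(seq[i:i + k] for i in range(len(seq) - k + 1))
    vocab = kmer_vocab(k)
    total = sum(counts[km] for km in vocab)  # only pure-ACGT windows count
    if total == 0:
        return [0.0] * len(vocab)
    return [counts[km] / total for km in vocab]
```

Normalising by the number of valid windows keeps plasmids of very different lengths on a comparable scale.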
Predict host genera for new plasmids from the command line:

```bash
plasmid-host-range predict path/to/plasmid.fasta --top-k 3
```

Project layout:

```text
src/plasmid_host_range/
  data/          # download, preprocess, splits, torch Dataset
  model.py       # HF AutoModelForSequenceClassification wrapper
  train.py       # fine-tuning loop
  evaluate.py    # metrics + confusion matrix
  predict.py     # predict_host_genus() public API
  baselines.py   # k-mer + logistic regression baseline
  cli.py         # typer CLI
tests/           # smoke + unit tests
configs/         # default.yaml, smoke.yaml
```
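`predict.py` is described only as the home of `predict_host_genus()`. Assuming it returns one result per FASTA record with the `accession`, `top_genera` and `scores` fields used in the example at the top, the last step from classifier logits to those fields could look like the sketch below. `GenusPrediction` and `top_k_from_logits` are hypothetical names, and treating `scores` as softmax probabilities is an assumption:

```python
import math
from dataclasses import dataclass


@dataclass
class GenusPrediction:
    """Hypothetical result type mirroring the fields used in the README example."""
    accession: str
    top_genera: list[str]
    scores: list[float]


def top_k_from_logits(accession: str, logits: list[float], labels: list[str],
                      top_k: int = 3) -> GenusPrediction:
    """Softmax over class logits, then keep the k most probable genera."""
    m = max(logits)  # subtract the max logit for numerical stability
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    probs = [e / z for e in exps]
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    return GenusPrediction(accession, [labels[i] for i in order], [probs[i] for i in order])
```

Here `labels` would be the class names in the order the classification head was trained on, including `Other`.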