23 changes: 21 additions & 2 deletions PW_FT_classification/main.py
@@ -55,8 +55,27 @@ def main(
     """
 
     # GPU configuration: set up GPUs based on availability and user specification
-    gpus = gpus if torch.cuda.is_available() else None
-    gpus = [int(i) for i in gpus.split(',')]
+    if torch.cuda.is_available():
+        # Sanitize and validate the GPU list provided via CLI
+        if gpus is None:
+            parsed_gpus = None
+        else:
+            tokens = [token.strip() for token in str(gpus).split(",") if token.strip()]
+            if not tokens:
+                # Empty or whitespace-only input: fall back to CPU
+                parsed_gpus = None
+            else:
+                try:
+                    parsed_gpus = [int(token) for token in tokens]
+                except ValueError as exc:
+                    raise typer.BadParameter(
+                        f"Invalid GPU list '{gpus}'. Expected a comma-separated list of integers, e.g. '0,1,2'."
+                    ) from exc
+        gpus = parsed_gpus
+    else:
+        # If no CUDA devices are available, set gpus to None to indicate CPU usage
+        # PyTorch Lightning Trainer will default to CPU if devices is None
+        gpus = None
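The added parsing logic can be exercised in isolation. Below is a minimal standalone sketch of the same validation, using a hypothetical `parse_gpus` helper and a plain `ValueError` in place of `typer.BadParameter` so it runs without Typer installed:

```python
def parse_gpus(gpus):
    """Parse a comma-separated GPU string into a list of ints, or None.

    Standalone sketch of the validation added in this PR; raises ValueError
    instead of typer.BadParameter so it has no Typer dependency.
    """
    if gpus is None:
        return None
    # Split on commas and drop empty/whitespace-only tokens
    tokens = [t.strip() for t in str(gpus).split(",") if t.strip()]
    if not tokens:
        # Empty or whitespace-only input: fall back to CPU
        return None
    try:
        return [int(t) for t in tokens]
    except ValueError as exc:
        raise ValueError(
            f"Invalid GPU list '{gpus}'. Expected a comma-separated "
            f"list of integers, e.g. '0,1,2'."
        ) from exc
```

For example, `parse_gpus("0, 1")` yields `[0, 1]`, while `parse_gpus("")` and `parse_gpus(None)` both yield `None`.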
Comment on lines +58 to +78
Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On CPU-only systems this still looks likely to crash because the Trainer is configured with accelerator='gpu' unconditionally (line 162). Setting gpus=None will not make Lightning fall back to CPU when the accelerator is explicitly GPU. Consider deriving both accelerator and devices from torch.cuda.is_available() (e.g., accelerator='cpu', devices=1 when CUDA is unavailable, or use accelerator='auto'/devices='auto'). Also, the inline comment about defaulting to CPU when devices is None is misleading in the current configuration.
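The reviewer's suggestion can be sketched as a small helper that derives both Trainer arguments together, so `accelerator='gpu'` is never paired with an empty device list. The helper name `trainer_device_args` is hypothetical, and CUDA availability is passed in as a parameter so the sketch runs without `torch`:

```python
def trainer_device_args(gpus, cuda_available):
    """Derive Trainer accelerator/devices kwargs from CUDA availability.

    Hypothetical helper illustrating the review suggestion: `gpus` is the
    parsed list of GPU indices (or None), `cuda_available` would normally
    come from torch.cuda.is_available().
    """
    if cuda_available and gpus:
        return {"accelerator": "gpu", "devices": gpus}
    # Explicit CPU fallback avoids the crash described above, where
    # accelerator='gpu' is set unconditionally while devices is None
    return {"accelerator": "cpu", "devices": 1}
```

These kwargs would then be unpacked into the Trainer constructor (e.g. `pl.Trainer(**trainer_device_args(gpus, torch.cuda.is_available()), ...)`), replacing the hard-coded `accelerator='gpu'`.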
     # Environment variable setup for numpy multi-threading. It is important to avoid cpu and ram issues.
     os.environ["OMP_NUM_THREADS"] = str(np_threads)