diff --git a/PW_FT_classification/main.py b/PW_FT_classification/main.py index ec66333a1..4caa2236e 100644 --- a/PW_FT_classification/main.py +++ b/PW_FT_classification/main.py @@ -55,8 +55,27 @@ def main( """ # GPU configuration: set up GPUs based on availability and user specification - gpus = gpus if torch.cuda.is_available() else None - gpus = [int(i) for i in gpus.split(',')] + if torch.cuda.is_available(): + # Sanitize and validate the GPU list provided via CLI + if gpus is None: + parsed_gpus = None + else: + tokens = [token.strip() for token in str(gpus).split(",") if token.strip()] + if not tokens: + # Empty or whitespace-only input: fall back to CPU + parsed_gpus = None + else: + try: + parsed_gpus = [int(token) for token in tokens] + except ValueError as exc: + raise typer.BadParameter( + f"Invalid GPU list '{gpus}'. Expected a comma-separated list of integers, e.g. '0,1,2'." + ) from exc + gpus = parsed_gpus + else: + # If no CUDA devices are available, set gpus to None to indicate CPU usage + # PyTorch Lightning Trainer will default to CPU if devices is None + gpus = None # Environment variable setup for numpy multi-threading. It is important to avoid cpu and ram issues. os.environ["OMP_NUM_THREADS"] = str(np_threads)