From abd9aa0d3509fca1184c9aa3f78db62cf0296368 Mon Sep 17 00:00:00 2001
From: Tal Golan
Date: Thu, 11 Jan 2024 15:03:03 +0200
Subject: [PATCH] load models even if you don't have a GPU

---
 robustness/model_utils.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/robustness/model_utils.py b/robustness/model_utils.py
index eb34998..a9b3e61 100644
--- a/robustness/model_utils.py
+++ b/robustness/model_utils.py
@@ -90,8 +90,11 @@ def make_and_restore_model(*_, arch, dataset, resume_path=None,
     checkpoint = None
     if resume_path and os.path.isfile(resume_path):
         print("=> loading checkpoint '{}'".format(resume_path))
-        checkpoint = ch.load(resume_path, pickle_module=dill)
-
+        if ch.cuda.is_available():
+            checkpoint = ch.load(resume_path, pickle_module=dill)
+        else:
+            checkpoint = ch.load(resume_path, pickle_module=dill, map_location=ch.device('cpu'))
+
         # Makes us able to load models saved with legacy versions
         state_dict_path = 'model'
         if not ('model' in checkpoint):
@@ -107,7 +110,8 @@ def make_and_restore_model(*_, arch, dataset, resume_path=None,
 
     if parallel:
         model = ch.nn.DataParallel(model)
-    model = model.cuda()
+    if ch.cuda.is_available():
+        model = model.cuda()
 
     return model, checkpoint
 
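
Note (not part of the patch): below is a minimal standalone sketch of the device-aware loading pattern this patch introduces, assuming torch is imported as `ch` (as in robustness/model_utils.py) and a dill-pickled checkpoint sits at a hypothetical path 'checkpoint.pt.best'.

    import dill
    import torch as ch

    resume_path = 'checkpoint.pt.best'  # hypothetical checkpoint path

    if ch.cuda.is_available():
        # GPU present: load the checkpoint onto its original (CUDA) devices.
        checkpoint = ch.load(resume_path, pickle_module=dill)
    else:
        # No GPU: remap all tensors onto the CPU so loading does not fail.
        checkpoint = ch.load(resume_path, pickle_module=dill,
                             map_location=ch.device('cpu'))

The same guard is applied before the final `model.cuda()` call, so the restored model stays on the CPU when CUDA is unavailable.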