From ce38b0f881ad1b952899716bf97fa81c39e6dd10 Mon Sep 17 00:00:00 2001 From: Taksh Date: Sat, 2 May 2026 06:47:14 +0530 Subject: [PATCH] fix: add weights_only=True to torch.load calls in apply_delta.py torch.load without weights_only=True uses pickle deserialization, which allows arbitrary code execution from a malicious checkpoint file. Pass weights_only=True to all four torch.load calls so only tensor data is loaded. Fixes #3777 --- fastchat/model/apply_delta.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fastchat/model/apply_delta.py b/fastchat/model/apply_delta.py index ba1c06d48..d22159b1c 100644 --- a/fastchat/model/apply_delta.py +++ b/fastchat/model/apply_delta.py @@ -34,7 +34,7 @@ def split_files(model_path, tmp_path, split_size): part = 0 try: for file_path in tqdm(files): - state_dict = torch.load(file_path) + state_dict = torch.load(file_path, weights_only=True) new_state_dict = {} current_size = 0 @@ -87,19 +87,19 @@ def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path): base_files = glob.glob(base_pattern) delta_pattern = os.path.join(tmp_delta_path, "pytorch_model-*.bin") delta_files = glob.glob(delta_pattern) - delta_state_dict = torch.load(delta_files[0]) + delta_state_dict = torch.load(delta_files[0], weights_only=True) print("Applying the delta") weight_map = {} total_size = 0 for i, base_file in tqdm(enumerate(base_files)): - state_dict = torch.load(base_file) + state_dict = torch.load(base_file, weights_only=True) file_name = f"pytorch_model-{i}.bin" for name, param in state_dict.items(): if name not in delta_state_dict: for delta_file in delta_files: - delta_state_dict = torch.load(delta_file) + delta_state_dict = torch.load(delta_file, weights_only=True) gc.collect() if name in delta_state_dict: break