diff --git a/auto_round/utils/missing_tensors.py b/auto_round/utils/missing_tensors.py
index 94f3d7be7..dd82c537b 100644
--- a/auto_round/utils/missing_tensors.py
+++ b/auto_round/utils/missing_tensors.py
@@ -617,6 +617,9 @@ def _update_qcfg(qcfg, show_log=True):
         )
     if weight_keys:
         existing = qcfg.get("block_name_to_quantize") or []
+        # If the config never declared block_name_to_quantize, leave it untouched
+        # rather than introducing the attribute during the WOQ copy step.
+        if not existing:
+            return qcfg
         if isinstance(existing, str):
             existing = [b.strip() for b in existing.split(",") if b.strip()]
         existing_set = set(existing)
diff --git a/test/test_cpu/utils/test_missing_tensors.py b/test/test_cpu/utils/test_missing_tensors.py
index 190404f86..a167ff908 100644
--- a/test/test_cpu/utils/test_missing_tensors.py
+++ b/test/test_cpu/utils/test_missing_tensors.py
@@ -248,8 +248,8 @@ class TestCopyMissingTensorsFromSource(unittest.TestCase):
     Special case: 'lm_head.weight' is always excluded.
     """
 
-    def _make_auto_round_config(self, bits=4, group_size=128, sym=True) -> dict:
-        return {
+    def _make_auto_round_config(self, bits=4, group_size=128, sym=True, block_name_to_quantize=None) -> dict:
+        qcfg = {
             "quantization_config": {
                 "quant_method": "auto-round",
                 "packing_format": "auto_round:auto_gptq",
@@ -258,6 +258,9 @@ def _make_auto_round_config(self, bits=4, group_size=128, sym=True) -> dict:
                 "sym": sym,
             }
         }
+        if block_name_to_quantize is not None:
+            qcfg["quantization_config"]["block_name_to_quantize"] = block_name_to_quantize
+        return qcfg
 
     # ------------------------------------------------------------------ #
     # Detection correctness                                              #
@@ -489,6 +492,34 @@ def test_block_name_to_ignore_not_quantized_in_woq_mode(self):
         self.assertIn("mtp.0.mlp.gate.weight", result)
         self.assertNotIn("mtp.0.mlp.gate.qweight", result)
 
+    def test_config_updated_without_block_name_to_quantize_after_woq(self):
+        """After WOQ, a config.json lacking block_name_to_quantize stays without one."""
+        out_features, in_features = 32, 128
+        with tempfile.TemporaryDirectory() as source_dir, \
+                tempfile.TemporaryDirectory() as target_dir:
+            _save_safetensors(
+                {"mtp.0.ffn.weight": torch.randn(out_features, in_features)},
+                os.path.join(source_dir, "model.safetensors"),
+            )
+            _save_safetensors(
+                {"model.embed_tokens.weight": torch.randn(8, 64)},
+                os.path.join(target_dir, "model.safetensors"),
+            )
+            raw_qcfg = self._make_auto_round_config(bits=4, group_size=128, sym=True)
+            with open(os.path.join(target_dir, "config.json"), "w") as f:
+                json.dump(raw_qcfg, f)
+
+            copy_missing_tensors_from_source(source_dir, target_dir)
+
+            with open(os.path.join(target_dir, "config.json")) as f:
+                updated_cfg = json.load(f)
+            qcfg = updated_cfg.get("quantization_config", {})
+            block_names = qcfg.get("block_name_to_quantize", [])
+            self.assertTrue(
+                block_names == [],
+                f"Expected no block_name_to_quantize, got: {block_names}",
+            )
+
     def test_config_updated_with_block_name_to_quantize_after_woq(self):
         """After WOQ, config.json is updated so the new block appears in block_name_to_quantize."""
         out_features, in_features = 32, 128
@@ -501,8 +532,11 @@ def test_config_updated_with_block_name_to_quantize_after_woq(self):
             {"model.embed_tokens.weight": torch.randn(8, 64)},
             os.path.join(target_dir, "model.safetensors"),
         )
+        raw_qcfg = self._make_auto_round_config(
+            bits=4, group_size=128, sym=True, block_name_to_quantize=["model.layers"]
+        )
         with open(os.path.join(target_dir, "config.json"), "w") as f:
-            json.dump(self._make_auto_round_config(bits=4, group_size=128, sym=True), f)
+            json.dump(raw_qcfg, f)
 
         copy_missing_tensors_from_source(source_dir, target_dir)