Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions auto_round/utils/missing_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,9 @@ def _update_qcfg(qcfg, show_log=True):
)
if weight_keys:
existing = qcfg.get("block_name_to_quantize") or []
# In cases where this attribute is lacking
if not existing:
return qcfg
if isinstance(existing, str):
Comment thread
xin3he marked this conversation as resolved.
existing = [b.strip() for b in existing.split(",") if b.strip()]
existing_set = set(existing)
Expand Down
40 changes: 37 additions & 3 deletions test/test_cpu/utils/test_missing_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,8 @@ class TestCopyMissingTensorsFromSource(unittest.TestCase):
Special case: 'lm_head.weight' is always excluded.
"""

def _make_auto_round_config(self, bits=4, group_size=128, sym=True) -> dict:
return {
def _make_auto_round_config(self, bits=4, group_size=128, sym=True, block_name_to_quantize=None) -> dict:
qcfg = {
"quantization_config": {
"quant_method": "auto-round",
"packing_format": "auto_round:auto_gptq",
Expand All @@ -258,6 +258,9 @@ def _make_auto_round_config(self, bits=4, group_size=128, sym=True) -> dict:
"sym": sym,
}
}
if block_name_to_quantize is not None:
qcfg["quantization_config"]["block_name_to_quantize"] = block_name_to_quantize
return qcfg

# ------------------------------------------------------------------ #
# Detection correctness #
Expand Down Expand Up @@ -489,6 +492,34 @@ def test_block_name_to_ignore_not_quantized_in_woq_mode(self):
self.assertIn("mtp.0.mlp.gate.weight", result)
self.assertNotIn("mtp.0.mlp.gate.qweight", result)

def test_config_updated_without_block_name_to_quantize_after_woq(self):
    """After WOQ copy, a config that lacks block_name_to_quantize stays empty.

    The raw config written to ``target_dir`` deliberately omits
    ``block_name_to_quantize``; after ``copy_missing_tensors_from_source``
    runs, the rewritten config.json must still carry no block names for the
    newly copied tensor (this exercises the early-return branch added to
    ``_update_qcfg`` for configs missing that attribute).
    """
    out_features, in_features = 32, 128
    with tempfile.TemporaryDirectory() as source_dir, tempfile.TemporaryDirectory() as target_dir:
        # Source checkpoint holds a tensor that the target is missing.
        _save_safetensors(
            {"mtp.0.ffn.weight": torch.randn(out_features, in_features)},
            os.path.join(source_dir, "model.safetensors"),
        )
        # Target checkpoint without the MTP tensor, so the copy is triggered.
        _save_safetensors(
            {"model.embed_tokens.weight": torch.randn(8, 64)},
            os.path.join(target_dir, "model.safetensors"),
        )
        # NOTE: no block_name_to_quantize argument — the key is absent on purpose.
        raw_qcfg = self._make_auto_round_config(bits=4, group_size=128, sym=True)
        with open(os.path.join(target_dir, "config.json"), "w") as f:
            json.dump(raw_qcfg, f)

        copy_missing_tensors_from_source(source_dir, target_dir)

        # Re-read the config the function may have rewritten on disk.
        with open(os.path.join(target_dir, "config.json")) as f:
            updated_cfg = json.load(f)

        qcfg = updated_cfg.get("quantization_config", {})
        block_names = qcfg.get("block_name_to_quantize", [])
        self.assertTrue(
            block_names == [],
            f"Expected no block_name_to_quantize, got: {block_names}",
        )

def test_config_updated_with_block_name_to_quantize_after_woq(self):
"""After WOQ, config.json is updated so the new block appears in block_name_to_quantize."""
out_features, in_features = 32, 128
Expand All @@ -501,8 +532,11 @@ def test_config_updated_with_block_name_to_quantize_after_woq(self):
{"model.embed_tokens.weight": torch.randn(8, 64)},
os.path.join(target_dir, "model.safetensors"),
)
raw_qcfg = self._make_auto_round_config(
bits=4, group_size=128, sym=True, block_name_to_quantize=["model.layers"]
)
with open(os.path.join(target_dir, "config.json"), "w") as f:
json.dump(self._make_auto_round_config(bits=4, group_size=128, sym=True), f)
json.dump(raw_qcfg, f)

copy_missing_tensors_from_source(source_dir, target_dir)

Expand Down
Loading