diff --git a/fastchat/train/train_lora.py b/fastchat/train/train_lora.py index 9ecb47c29..61b0437e6 100644 --- a/fastchat/train/train_lora.py +++ b/fastchat/train/train_lora.py @@ -92,9 +92,9 @@ def get_peft_state_maybe_zero_3(named_params, bias): lora_bias_names.add(bias_name) elif "bias" in k: maybe_lora_bias[k] = t - for k, t in maybe_lora_bias: - if bias_name in lora_bias_names: - to_return[bias_name] = t + for k, t in maybe_lora_bias.items(): + if k in lora_bias_names: + to_return[k] = t else: raise NotImplementedError to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}