Commit 02c9bae

Correct LoRA weights merging (#1784)
Correction of the merging code between the model's original layer weights and the LoRA layer weights. This respects the LoRA principle of discarding the LoRA layers once we no longer plan to train them, but more importantly it allows us to save and load the model as a ".keras" file.
1 parent 9f287fd commit 02c9bae
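
For context, the merge step in the diff below folds the low-rank update back into the frozen kernel: kernel <- kernel + (ALPHA / RANK) * (A @ B). Here is a minimal sketch of that computation using the shapes quoted in the diff; the random weights and the RANK/ALPHA values are illustrative stand-ins, not the example's trained values:

import tensorflow as tf

# Illustrative hyperparameters; the example file defines its own RANK and ALPHA.
RANK = 1
ALPHA = 32.0

# Shapes quoted in the diff: A is (768, 1) -> (a, b), B is (1, 12, 64) -> (b, c, d),
# so the increment matches the original attention kernel's (768, 12, 64) -> (a, c, d).
A_weights = tf.random.normal((768, RANK))
B_weights = tf.random.normal((RANK, 12, 64))
original_kernel = tf.Variable(tf.random.normal((768, 12, 64)))

# The scaled LoRA update, computed exactly as in the diff below.
increment_weights = tf.einsum("ab,bcd->acd", A_weights, B_weights) * (ALPHA / RANK)
original_kernel.assign_add(increment_weights)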

File tree

1 file changed: +4 -0 lines changed

examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora.py (+4)
@@ -587,6 +587,10 @@ def call(self, inputs):
     B_weights = value_lora_layer.B.kernel  # (1, 12, 64) (b, c, d)
     increment_weights = tf.einsum("ab,bcd->acd", A_weights, B_weights) * (ALPHA / RANK)
     value_lora_layer.original_layer.kernel.assign_add(increment_weights)
+
+    # Put back in place the original layers with updated weights
+    self_attention_layer._query_dense = query_lora_layer.original_layer
+    self_attention_layer._value_dense = value_lora_layer.original_layer
 
 """
 We are now all set to generate text with our LoRA model :).
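
The payoff described in the commit message is serialization: with the updated original layers swapped back in, the model consists of standard layers only. A hedged usage sketch, assuming the example's lora_model variable and a hypothetical file name:

import keras
import keras_nlp  # imported so KerasNLP classes are registered for deserialization

# Save the merged model; any path ending in ".keras" works.
lora_model.save("gpt2_lora_merged.keras")

# The round trip now succeeds because no custom LoRA layers remain.
restored_model = keras.models.load_model("gpt2_lora_merged.keras")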
