
GPT2 Model performance dismal with enable_lora #2162

Open
@ashep29

Description

The following code trains fine: the loss decreases and accuracy improves.

import keras
import keras_nlp

# train_ds and num_epochs are defined elsewhere in my setup.
chosen_preset = "gpt2_base_en"

preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    chosen_preset,
    sequence_length=128,
)

gpt2_causal_lm = keras_nlp.models.GPT2CausalLM.from_preset(
    chosen_preset, preprocessor=preprocessor
)

learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-5,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimiser = keras.optimizers.Adam(learning_rate)
metrics = [keras.metrics.SparseCategoricalAccuracy()]

gpt2_causal_lm.compile(
    optimizer=optimiser,
    loss=loss,
    weighted_metrics=metrics,
)
gpt2_causal_lm.fit(train_ds, epochs=num_epochs)


Unfortunately, when LoRA is enabled, training fails dismally. All settings remain the same except for the added .quantize and enable_lora calls, yet the loss increases and accuracy is extremely low (around 1%, as opposed to 46% without LoRA).

gpt2_causal_lm_with_qlora = keras_nlp.models.GPT2CausalLM.from_preset(
    chosen_preset, preprocessor=preprocessor
)

gpt2_causal_lm_with_qlora.quantize("int8")
gpt2_causal_lm_with_qlora.backbone.enable_lora(rank=4)
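
For completeness, the quantized + LoRA model is compiled and fit with exactly the same settings as the non-LoRA run above (roughly along these lines; the schedule, optimizer, loss, and metrics are rebuilt here only so the snippet stands on its own):

# Same training settings as the full-precision run above.
learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-5,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)

gpt2_causal_lm_with_qlora.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gpt2_causal_lm_with_qlora.fit(train_ds, epochs=num_epochs)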

The model produces junk output, e.g.:

This
12
L
12
This
12
[26262626
12
[12.
[12.
[[.
[[. the r)
[2
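
(The output above came from generation along these lines; the prompt string and max_length are illustrative placeholders.)

# Illustrative prompt and length only; actual values vary.
output = gpt2_causal_lm_with_qlora.generate(
    "This is a test",
    max_length=64,
)
print(output)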

Is there anything wrong with how I'm using LoRA, or is it simply not implemented or available for use with GPT2?
