diff --git a/tests/lora/utils.py b/tests/lora/utils.py index efa49b9f4838..547dbc8a5fb3 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -635,7 +635,7 @@ def test_simple_inference_with_partial_text_lora(self): state_dict = { f"text_encoder.{module_name}": param for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items() - if "text_model.encoder.layers.4" not in module_name + if "encoder.layers.4" not in module_name } if self.has_two_text_encoders or self.has_three_text_encoders: @@ -644,7 +644,7 @@ def test_simple_inference_with_partial_text_lora(self): { f"text_encoder_2.{module_name}": param for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items() - if "text_model.encoder.layers.4" not in module_name + if "encoder.layers.4" not in module_name } ) @@ -776,8 +776,9 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self): ) if "text_encoder" in self.pipeline_class._lora_loadable_modules: + text_encoder_root = getattr(pipe.text_encoder, "text_model", pipe.text_encoder) self.assertTrue( - pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0, + text_encoder_root.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0, "The scaling parameter has not been correctly restored!", )