Commit 9349385

Merge pull request #287 from EliSchwartz/main
Fixed issue with granite-vision QLORA training and made it the default
2 parents: fa40159 + fb4aa65

File tree

1 file changed: +8 -6


notebooks/en/fine_tuning_granite_vision_sft_trl.ipynb (+8 -6)
@@ -78,7 +78,7 @@
 "!pip install -q flash-attn --no-build-isolation\n",
 "\n",
 "try:\n",
-"    from flash_attn.flash_attention import FlashAttention\n",
+"    import flash_attn\n",
 "    print(\"FlashAttention is installed\")\n",
 "    USE_FLASH_ATTENTION = True\n",
 "except ImportError:\n",
@@ -639,16 +639,18 @@
 "source": [
 "from transformers import BitsAndBytesConfig\n",
 "\n",
-"USE_QLORA = False\n",
-"USE_LORA = False\n",
+"USE_QLORA = True\n",
+"USE_LORA = True\n",
 "\n",
 "if USE_QLORA:\n",
 "    # BitsAndBytesConfig int-4 config\n",
 "    bnb_config = BitsAndBytesConfig(\n",
 "        load_in_4bit=True,\n",
 "        bnb_4bit_use_double_quant=True,\n",
 "        bnb_4bit_quant_type=\"nf4\",\n",
-"        bnb_4bit_compute_dtype=torch.bfloat16\n",
+"        bnb_4bit_compute_dtype=torch.bfloat16,\n",
+"        llm_int8_skip_modules=[\"vision_tower\", \"lm_head\"],  # Skip problematic modules\n",
+"        llm_int8_enable_fp32_cpu_offload=True\n",
 "    )\n",
 "else:\n",
 "    bnb_config = None\n",
@@ -693,7 +695,6 @@
 "    r=8,\n",
 "    lora_alpha=8,\n",
 "    lora_dropout=0.1,\n",
-"    # target_modules=['down_proj','o_proj','k_proj','q_proj','gate_proj','up_proj','v_proj'],\n",
 "    target_modules=[name for name, _ in model.named_modules() if 'language_model' in name and '_proj' in name],\n",
 "    use_dora=True,\n",
 "    init_lora_weights=\"gaussian\"\n",
@@ -1052,7 +1053,8 @@
 "outputs": [],
 "source": [
 "if USE_LORA:\n",
-"    model = model.merge_and_unload().to(torch.bfloat16)"
+"    from peft import PeftModel\n",
+"    model = PeftModel.from_pretrained(model, training_args.output_dir)"
 ]
 },
 {
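After training, the notebook now re-attaches the saved adapter instead of merging it into bfloat16 weights with merge_and_unload(). A sketch, assuming model is the base granite-vision model and training_args.output_dir is where the Trainer wrote the adapter:

if USE_LORA:
    from peft import PeftModel

    # Reload the trained LoRA/DoRA adapter from the Trainer's output
    # directory onto the base model for inference.
    model = PeftModel.from_pretrained(model, training_args.output_dir)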
