This repository was archived by the owner on Jan 24, 2024. It is now read-only.
File tree: 4 files changed, +15 −1 lines changed
--- File 1 of 4 (environment-variable configuration module) ---
@@ -18,6 +18,7 @@ def is_true(value):
 PORT = int(os.getenv("PORT", "80"))
 
 # Model-related arguments:
+MODEL_PEFT = is_true(os.getenv("MODEL_PEFT", ""))
 MODEL_REVISION = os.getenv("MODEL_REVISION", "")
 MODEL_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", "models")
 MODEL_LOAD_IN_8BIT = is_true(os.getenv("MODEL_LOAD_IN_8BIT", ""))
--- File 2 of 4 (package module: config imports and load_model call site) ---
@@ -23,6 +23,7 @@
 from . import MODEL_LOAD_IN_4BIT
 from . import MODEL_4BIT_QUANT_TYPE
 from . import MODEL_4BIT_DOUBLE_QUANT
+from . import MODEL_PEFT
 from . import MODEL_LOCAL_FILES_ONLY
 from . import MODEL_TRUST_REMOTE_CODE
 from . import MODEL_HALF_PRECISION
@@ -44,6 +45,7 @@
     name_or_path=MODEL,
     revision=MODEL_REVISION,
     cache_dir=MODEL_CACHE_DIR,
+    is_peft=MODEL_PEFT,
     load_in_8bit=MODEL_LOAD_IN_8BIT,
     load_in_4bit=MODEL_LOAD_IN_4BIT,
     quant_type=MODEL_4BIT_QUANT_TYPE,
--- File 3 of 4 (model-loading module containing load_model) ---
@@ -14,6 +14,10 @@
     TopPLogitsWarper,
     BitsAndBytesConfig
 )
+from peft import (
+    PeftConfig,
+    PeftModel
+)
 
 from .choice import map_choice
 from .tokenizer import StreamTokenizer
@@ -310,6 +314,7 @@ def load_model(
     name_or_path,
     revision=None,
     cache_dir=None,
+    is_peft=False,
     load_in_8bit=False,
     load_in_4bit=False,
     quant_type="fp4",
@@ -327,7 +332,6 @@ def load_model(
         kwargs["revision"] = revision
     if cache_dir:
         kwargs["cache_dir"] = cache_dir
-    tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)
 
     # Set device mapping and quantization options if CUDA is available.
     if torch.cuda.is_available():
@@ -354,6 +358,12 @@ def load_model(
     if half_precision or load_in_8bit or load_in_4bit:
         kwargs["torch_dtype"] = torch.float16
 
+    if is_peft:
+        peft_config = PeftConfig.from_pretrained(name_or_path)
+        name_or_path = peft_config.base_model_name_or_path
+
+    tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)
+
     # Support both decoder-only and encoder-decoder models.
     try:
         model = AutoModelForCausalLM.from_pretrained(name_or_path, **kwargs)
--- File 4 of 4 (requirements.txt) ---
@@ -12,3 +12,4 @@ safetensors~=0.3.1
 torch>=1.12.1
 transformers[sentencepiece]~=4.30.1
 waitress~=2.1.2
+peft~=0.3.0
You can’t perform that action at this time.
0 commit comments