Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit ef138a0

Browse files
committed
Support loading PEFT (LoRA) models
1 parent e4dabb0 commit ef138a0

File tree

4 files changed

+15
-1
lines changed

basaran/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def is_true(value):
1818
PORT = int(os.getenv("PORT", "80"))
1919

2020
# Model-related arguments:
21+
MODEL_PEFT = is_true(os.getenv("MODEL_PEFT", ""))
2122
MODEL_REVISION = os.getenv("MODEL_REVISION", "")
2223
MODEL_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", "models")
2324
MODEL_LOAD_IN_8BIT = is_true(os.getenv("MODEL_LOAD_IN_8BIT", ""))

basaran/__main__.py

+2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from . import MODEL_LOAD_IN_4BIT
2424
from . import MODEL_4BIT_QUANT_TYPE
2525
from . import MODEL_4BIT_DOUBLE_QUANT
26+
from . import MODEL_PEFT
2627
from . import MODEL_LOCAL_FILES_ONLY
2728
from . import MODEL_TRUST_REMOTE_CODE
2829
from . import MODEL_HALF_PRECISION
@@ -44,6 +45,7 @@
4445
name_or_path=MODEL,
4546
revision=MODEL_REVISION,
4647
cache_dir=MODEL_CACHE_DIR,
48+
is_peft=MODEL_PEFT,
4749
load_in_8bit=MODEL_LOAD_IN_8BIT,
4850
load_in_4bit=MODEL_LOAD_IN_4BIT,
4951
quant_type=MODEL_4BIT_QUANT_TYPE,

basaran/model.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
TopPLogitsWarper,
1515
BitsAndBytesConfig
1616
)
17+
from peft import (
18+
PeftConfig,
19+
PeftModel
20+
)
1721

1822
from .choice import map_choice
1923
from .tokenizer import StreamTokenizer
@@ -310,6 +314,7 @@ def load_model(
310314
name_or_path,
311315
revision=None,
312316
cache_dir=None,
317+
is_peft=False,
313318
load_in_8bit=False,
314319
load_in_4bit=False,
315320
quant_type="fp4",
@@ -327,7 +332,6 @@ def load_model(
327332
kwargs["revision"] = revision
328333
if cache_dir:
329334
kwargs["cache_dir"] = cache_dir
330-
tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)
331335

332336
# Set device mapping and quantization options if CUDA is available.
333337
if torch.cuda.is_available():
@@ -354,6 +358,12 @@ def load_model(
354358
if half_precision or load_in_8bit or load_in_4bit:
355359
kwargs["torch_dtype"] = torch.float16
356360

361+
if is_peft:
362+
peft_config = PeftConfig.from_pretrained(name_or_path)
363+
name_or_path = peft_config.base_model_name_or_path
364+
365+
tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)
366+
357367
# Support both decoder-only and encoder-decoder models.
358368
try:
359369
model = AutoModelForCausalLM.from_pretrained(name_or_path, **kwargs)

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ safetensors~=0.3.1
1212
torch>=1.12.1
1313
transformers[sentencepiece]~=4.30.1
1414
waitress~=2.1.2
15+
peft~=0.3.0

0 commit comments

Comments (0)