Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit 008f8f0

Browse files
committed
Support loading PEFT (LoRA) models
1 parent a368310 commit 008f8f0

File tree

4 files changed

+13
-0
lines changed

4 files changed

+13
-0
lines changed

basaran/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def is_true(value):
1818
PORT = int(os.getenv("PORT", "80"))
1919

2020
# Model-related arguments:
21+
MODEL_PEFT = is_true(os.getenv("MODEL_PEFT", ""))
2122
MODEL_REVISION = os.getenv("MODEL_REVISION", "")
2223
MODEL_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", "models")
2324
MODEL_LOAD_IN_8BIT = is_true(os.getenv("MODEL_LOAD_IN_8BIT", ""))

basaran/__main__.py

+2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from . import MODEL_LOAD_IN_4BIT
2424
from . import MODEL_4BIT_QUANT_TYPE
2525
from . import MODEL_4BIT_DOUBLE_QUANT
26+
from . import MODEL_PEFT
2627
from . import MODEL_LOCAL_FILES_ONLY
2728
from . import MODEL_TRUST_REMOTE_CODE
2829
from . import MODEL_HALF_PRECISION
@@ -44,6 +45,7 @@
4445
name_or_path=MODEL,
4546
revision=MODEL_REVISION,
4647
cache_dir=MODEL_CACHE_DIR,
48+
is_peft=MODEL_PEFT,
4749
load_in_8bit=MODEL_LOAD_IN_8BIT,
4850
load_in_4bit=MODEL_LOAD_IN_4BIT,
4951
quant_type=MODEL_4BIT_QUANT_TYPE,

basaran/model.py

+9
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
TopPLogitsWarper,
1515
BitsAndBytesConfig
1616
)
17+
from peft import (
18+
PeftConfig,
19+
PeftModel
20+
)
1721

1822
from .choice import map_choice
1923
from .tokenizer import StreamTokenizer
@@ -302,6 +306,7 @@ def load_model(
302306
name_or_path,
303307
revision=None,
304308
cache_dir=None,
309+
is_peft=False,
305310
load_in_8bit=False,
306311
load_in_4bit=False,
307312
quant_type="fp4",
@@ -346,6 +351,10 @@ def load_model(
346351
if half_precision or load_in_8bit or load_in_4bit:
347352
kwargs["torch_dtype"] = torch.float16
348353

354+
if is_peft:
355+
peft_config = PeftConfig.from_pretrained(name_or_path)
356+
name_or_path = peft_config.base_model_name_or_path
357+
349358
# Support both decoder-only and encoder-decoder models.
350359
try:
351360
model = AutoModelForCausalLM.from_pretrained(name_or_path, **kwargs)

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ safetensors~=0.3.1
1212
torch>=1.12.1
1313
transformers[sentencepiece]~=4.29.2
1414
waitress~=2.1.2
15+
peft~=0.3.0

0 commit comments

Comments (0)