Skip to content

Commit c6b3c29

Browse files
committed
Merge remote-tracking branch 'origin' into kylesayrs/remove-double-init
2 parents cda0288 + d43ea79 commit c6b3c29

17 files changed

+484
-324
lines changed

src/llmcompressor/__init__.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@
3838
active_session,
3939
callbacks,
4040
create_session,
41-
finalize,
42-
initialize,
4341
reset_session,
4442
)
45-
from llmcompressor.entrypoints import Oneshot, oneshot
43+
from llmcompressor.entrypoints import Oneshot, oneshot, train

src/llmcompressor/core/__init__.py

-4
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
active_session,
1414
callbacks,
1515
create_session,
16-
finalize,
17-
initialize,
1816
reset_session,
1917
)
2018
from llmcompressor.core.state import Data, Hardware, ModifiedState, State
@@ -35,8 +33,6 @@
3533
"create_session",
3634
"active_session",
3735
"reset_session",
38-
"initialize",
39-
"finalize",
4036
"apply",
4137
"callbacks",
4238
"LifecycleCallbacks",

src/llmcompressor/core/session_functions.py

+1-77
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
import threading
22
from contextlib import contextmanager
3-
from typing import Any, Dict, List, Optional, Union
3+
from typing import Any, Optional
44

55
from llmcompressor.core.events import EventType
66
from llmcompressor.core.session import CompressionSession
77
from llmcompressor.core.state import ModifiedState
8-
from llmcompressor.recipe import Recipe
98

109
__all__ = [
1110
"create_session",
1211
"active_session",
1312
"reset_session",
14-
"initialize",
15-
"finalize",
1613
"callbacks",
1714
"LifecycleCallbacks",
1815
]
@@ -58,79 +55,6 @@ def reset_session():
5855
session._lifecycle.reset()
5956

6057

61-
def initialize(
62-
recipe: Union[str, List[str], "Recipe", List["Recipe"], None] = None,
63-
recipe_stage: Union[str, List[str], None] = None,
64-
recipe_args: Optional[Dict[str, Any]] = None,
65-
model: Optional[Any] = None,
66-
teacher_model: Optional[Any] = None,
67-
optimizer: Optional[Any] = None,
68-
attach_optim_callbacks: bool = True,
69-
train_data: Optional[Any] = None,
70-
val_data: Optional[Any] = None,
71-
test_data: Optional[Any] = None,
72-
calib_data: Optional[Any] = None,
73-
copy_data: bool = True,
74-
start: Optional[float] = None,
75-
steps_per_epoch: Optional[int] = None,
76-
batches_per_step: Optional[int] = None,
77-
**kwargs,
78-
) -> ModifiedState:
79-
"""
80-
A method to initialize the active session for sparsification
81-
82-
:param recipe: the recipe to use for the sparsification, can be a path to a
83-
recipe file, a raw recipe string, a recipe object, or a list of recipe objects.
84-
:param recipe_stage: the stage to target for the sparsification
85-
:param recipe_args: the args to use for overriding the recipe defaults
86-
:param model: the model to sparsify
87-
:param teacher_model: the teacher model to use for knowledge distillation
88-
:param optimizer: the optimizer to use for the sparsification
89-
:param attach_optim_callbacks: True to attach the optimizer callbacks to the
90-
sparsification lifecycle, False otherwise
91-
:param train_data: the training data to use for the sparsification
92-
:param val_data: the validation data to use for the sparsification
93-
:param test_data: the testing data to use for the sparsification
94-
:param calib_data: the calibration data to use for the sparsification
95-
:param copy_data: True to copy the data, False otherwise
96-
:param start: the start epoch to use for the sparsification
97-
:param steps_per_epoch: the number of steps per epoch to use for the
98-
sparsification
99-
:param batches_per_step: the number of batches per step to use for
100-
sparsification
101-
:param kwargs: additional kwargs to pass to the lifecycle's initialize method
102-
:return: the modified state of the active session after initializing
103-
"""
104-
return active_session().initialize(
105-
recipe=recipe,
106-
recipe_stage=recipe_stage,
107-
recipe_args=recipe_args,
108-
model=model,
109-
teacher_model=teacher_model,
110-
optimizer=optimizer,
111-
attach_optim_callbacks=attach_optim_callbacks,
112-
train_data=train_data,
113-
val_data=val_data,
114-
test_data=test_data,
115-
calib_data=calib_data,
116-
copy_data=copy_data,
117-
start=start,
118-
steps_per_epoch=steps_per_epoch,
119-
batches_per_step=batches_per_step,
120-
**kwargs,
121-
)
122-
123-
124-
def finalize(**kwargs) -> ModifiedState:
125-
"""
126-
Method to finalize the active session for sparsification
127-
128-
:param kwargs: additional kwargs to pass to the lifecycle's finalize method
129-
:return: the modified state of the active session after finalizing
130-
"""
131-
return active_session().finalize(**kwargs)
132-
133-
13458
class LifecycleCallbacks:
13559
"""
13660
A class for invoking lifecycle events for the active session

0 commit comments

Comments
 (0)