Commit 6ef2004

Merge branch 'main' into gpt_oss_sink
2 parents: cf95e02 + 8677159

31 files changed: +62, -151 lines

src/optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py

Lines changed: 3 additions & 3 deletions
@@ -68,7 +68,7 @@ def __post_init__(self, **kwargs):
         self.image_size = self.rbln_config.image_size

     @classmethod
-    def wrap_model_if_needed(
+    def _wrap_model_if_needed(
         cls, model: torch.nn.Module, rbln_config: RBLNAutoencoderKLCosmosConfig
     ) -> torch.nn.Module:
         decoder_model = _VAECosmosDecoder(model)
@@ -98,7 +98,7 @@ def replaced_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

         compiled_models = {}
         if rbln_config.uses_encoder:
-            encoder_model, decoder_model = cls.wrap_model_if_needed(model, rbln_config)
+            encoder_model, decoder_model = cls._wrap_model_if_needed(model, rbln_config)
             enc_compiled_model = cls.compile(
                 encoder_model,
                 rbln_compile_config=rbln_config.compile_cfgs[0],
@@ -107,7 +107,7 @@ def replaced_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             )
             compiled_models["encoder"] = enc_compiled_model
         else:
-            decoder_model = cls.wrap_model_if_needed(model, rbln_config)
+            decoder_model = cls._wrap_model_if_needed(model, rbln_config)
             dec_compiled_model = cls.compile(
                 decoder_model,
                 rbln_compile_config=rbln_config.compile_cfgs[-1],

src/optimum/rbln/diffusers/models/controlnet.py

Lines changed: 1 addition & 1 deletion
@@ -118,7 +118,7 @@ def __post_init__(self, **kwargs):
         )

     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         use_encoder_hidden_states = False
         for down_block in model.down_blocks:
             if use_encoder_hidden_states := getattr(down_block, "has_cross_attention", False):

src/optimum/rbln/diffusers/models/transformers/prior_transformer.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def __post_init__(self, **kwargs):
         self.clip_std = artifacts["clip_std"]

     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         return _PriorTransformer(model).eval()

     @classmethod

src/optimum/rbln/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 1 addition & 1 deletion
@@ -185,7 +185,7 @@ def compute_embedding(
         )

     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         num_latent_frames = rbln_config.num_latent_frames
         latent_height = rbln_config.latent_height
         latent_width = rbln_config.latent_width

src/optimum/rbln/diffusers/models/transformers/transformer_sd3.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)

     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         return SD3Transformer2DModelWrapper(model).eval()

     @classmethod

src/optimum/rbln/diffusers/models/unets/unet_2d_condition.py

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@ class ADDEMBEDDING:
         self.add_embedding = ADDEMBEDDING(LINEAR1(self.in_features))

     @classmethod
-    def wrap_model_if_needed(
+    def _wrap_model_if_needed(
         cls, model: torch.nn.Module, rbln_config: RBLNUNet2DConditionModelConfig
     ) -> torch.nn.Module:
         if model.config.addition_embed_type == "text_time":

src/optimum/rbln/modeling.py

Lines changed: 2 additions & 45 deletions
@@ -34,49 +34,6 @@
 logger = get_logger(__name__)


-def _get_dtype(
-    cls,
-    dtype: Optional[Union[str, torch.dtype, dict]],
-    config: PretrainedConfig,
-) -> tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
-    dtype_orig = None
-
-    if dtype is not None:
-        if isinstance(dtype, str):
-            if dtype == "auto":
-                if hasattr(config, "dtype") and config.dtype is not None:
-                    dtype = config.dtype
-                else:
-                    dtype = torch.get_default_dtype()
-            elif hasattr(torch, dtype):
-                dtype = getattr(torch, dtype)
-            config.dtype = dtype
-        elif isinstance(dtype, torch.dtype):
-            config.dtype = dtype
-        elif isinstance(dtype, dict):
-            for key, curr_dtype in dtype.items():
-                if hasattr(config, key):
-                    value = getattr(config, key)
-                    curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
-                    value.dtype = curr_dtype
-            # main torch dtype for modules that aren't part of any sub-config
-            dtype = dtype.get("")
-            dtype = dtype if not isinstance(dtype, str) else getattr(torch, dtype)
-            config.dtype = dtype
-            if dtype is None:
-                dtype = torch.float32
-        else:
-            raise ValueError(f"Invalid dtype: {dtype}")
-
-        dtype_orig = cls._set_default_dtype(dtype)
-    else:
-        # Use default dtype
-        default_dtype = torch.get_default_dtype()
-        config.dtype = default_dtype
-
-    return config, dtype, dtype_orig
-
-
 class RBLNModel(RBLNBaseModel):
     @classmethod
     def update_kwargs(cls, kwargs):
@@ -97,13 +54,13 @@ def save_torch_artifacts(
         pass

     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         # Wrap the model if needed.
         return model

     @classmethod
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
-        model = cls.wrap_model_if_needed(model, rbln_config)
+        model = cls._wrap_model_if_needed(model, rbln_config)
         rbln_compile_config = rbln_config.compile_cfgs[0]
         compiled_model = cls.compile(
             model,
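For reference, the hook contract this rename preserves, as a minimal sketch rather than the library's actual code (RBLNMyModel and MyModelWrapper below are hypothetical names): get_compiled_model() calls cls._wrap_model_if_needed() before cls.compile(), and a subclass overrides the now-private hook to hand back a compile-friendly torch.nn.Module.

import torch

# Hypothetical subclass illustrating how the renamed hook is overridden.
class MyModelWrapper(torch.nn.Module):
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Expose only tensor outputs so the graph can be traced for compilation.
        return self.model(input_ids)

class RBLNMyModel:  # stands in for an RBLNModel subclass
    @classmethod
    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config) -> torch.nn.Module:
        # The base class returns `model` unchanged; subclasses return a wrapper in eval mode.
        return MyModelWrapper(model).eval()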

src/optimum/rbln/transformers/modeling_generic.py

Lines changed: 2 additions & 2 deletions
@@ -59,7 +59,7 @@ class RBLNTransformerEncoder(RBLNModel):
     rbln_dtype = "int64"

     @classmethod
-    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNTransformerEncoderConfig) -> nn.Module:
+    def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNTransformerEncoderConfig) -> nn.Module:
         class TransformerEncoderWrapper(nn.Module):
             # Parameters to disable for RBLN compilation
             DISABLED_PARAMS = {"return_dict", "use_cache"}
@@ -268,7 +268,7 @@ class RBLNModelForDepthEstimation(RBLNImageModel):
     auto_model_class = AutoModelForDepthEstimation

     @classmethod
-    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNImageModelConfig):
+    def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNImageModelConfig):
         class ImageModelWrapper(nn.Module):
             def __init__(self, model: "PreTrainedModel", rbln_config: RBLNImageModelConfig):
                 super().__init__()
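The two hunks above define their wrappers as local nn.Module classes inside the classmethod. A rough sketch of that inner-wrapper pattern follows; it is illustrative only, and the assumption that DISABLED_PARAMS are simply pinned to False before calling the wrapped model is not taken from the diff.

import torch
from torch import nn

# Illustrative only: the real TransformerEncoderWrapper in modeling_generic.py
# is more involved; this shows the idea of wrapping a model for compilation.
class EncoderWrapperSketch(nn.Module):
    DISABLED_PARAMS = {"return_dict", "use_cache"}  # kwargs disabled for RBLN compilation

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Forward plain tensors; flags the compiler cannot handle are forced off,
        # and the first element of the resulting tuple (hidden states) is returned.
        disabled = {name: False for name in self.DISABLED_PARAMS}
        return self.model(input_ids=input_ids, attention_mask=attention_mask, **disabled)[0]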

src/optimum/rbln/transformers/models/bart/modeling_bart.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
     support_causal_attn = True

     @classmethod
-    def wrap_model_if_needed(self, model: PreTrainedModel, rbln_config: RBLNBartForConditionalGenerationConfig):
+    def _wrap_model_if_needed(self, model: PreTrainedModel, rbln_config: RBLNBartForConditionalGenerationConfig):
         return BartWrapper(
             model, enc_max_seq_len=rbln_config.enc_max_seq_len, use_attention_mask=rbln_config.use_attention_mask
         )

src/optimum/rbln/transformers/models/bert/modeling_bert.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
     rbln_model_input_names = ["input_ids", "attention_mask"]

     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNBertModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNBertModelConfig) -> torch.nn.Module:
         return BertModelWrapper(model, rbln_config)

0 commit comments