Skip to content

Commit 498fb20

Browse files
authored
model: fix pytorch model (#328)
1 parent b9f7aa7 commit 498fb20

File tree

2 files changed

+13
-31
lines changed

2 files changed

+13
-31
lines changed

src/optimum/rbln/modeling.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -249,37 +249,11 @@ def get_pytorch_model(
249249
trust_remote_code: bool = False,
250250
# Some rbln-config should be applied before loading torch module (i.e. quantized llm)
251251
rbln_config: Optional[RBLNModelConfig] = None,
252-
dtype: Optional[Union[str, torch.dtype, dict]] = None,
253252
**kwargs,
254253
) -> "PreTrainedModel":
255254
kwargs = cls.update_kwargs(kwargs)
256255

257-
hf_class = cls.get_hf_class()
258-
259-
if dtype is not None:
260-
config = hf_class.config_class.from_pretrained(
261-
model_id,
262-
subfolder=subfolder,
263-
revision=revision,
264-
cache_dir=cache_dir,
265-
use_auth_token=use_auth_token,
266-
local_files_only=local_files_only,
267-
force_download=force_download,
268-
trust_remote_code=trust_remote_code,
269-
)
270-
271-
config, processed_dtype, dtype_orig = _get_dtype(
272-
cls=hf_class,
273-
dtype=dtype,
274-
config=config,
275-
)
276-
277-
kwargs["torch_dtype"] = processed_dtype
278-
279-
if dtype_orig is not None:
280-
hf_class._set_default_dtype(dtype_orig)
281-
282-
return hf_class.from_pretrained(
256+
return cls.get_hf_class().from_pretrained(
283257
model_id,
284258
subfolder=subfolder,
285259
revision=revision,

src/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -322,22 +322,30 @@ def get_pytorch_model(
322322
*args,
323323
rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None,
324324
num_hidden_layers: Optional[int] = None,
325+
trust_remote_code: Optional[bool] = None,
326+
torch_dtype: Optional[torch.dtype] = None,
327+
dtype: Optional[torch.dtype] = None,
325328
**kwargs,
326329
) -> PreTrainedModel:
327330
if rbln_config and rbln_config.quantization:
328331
model = cls.get_quantized_model(model_id, *args, rbln_config=rbln_config, **kwargs)
329332
else:
333+
# TODO : resolve how to control PreTrainedConfig for hf_kwargs
330334
if num_hidden_layers is not None:
331-
trust_remote_code = kwargs.get("trust_remote_code", None)
332335
config, kwargs = AutoConfig.from_pretrained(
333-
model_id, return_unused_kwargs=True, num_hidden_layers=num_hidden_layers, **kwargs
336+
model_id,
337+
return_unused_kwargs=True,
338+
trust_remote_code=trust_remote_code,
339+
num_hidden_layers=num_hidden_layers,
340+
**kwargs,
334341
)
335342
if hasattr(config, "layer_types"):
336343
config.layer_types = config.layer_types[:num_hidden_layers]
337344
kwargs["config"] = config
338-
kwargs["trust_remote_code"] = trust_remote_code
339345

340-
model = super().get_pytorch_model(model_id, *args, **kwargs)
346+
model = super().get_pytorch_model(
347+
model_id, *args, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype, dtype=dtype, **kwargs
348+
)
341349

342350
return model
343351

0 commit comments

Comments (0)