Skip to content

Commit dbe544f

Browse files
committed
review: extend and use common dtype helper
1 parent 4668740 commit dbe544f

File tree

3 files changed

+8
-17
lines changed

3 files changed

+8
-17
lines changed

optimum/neuron/models/inference/nxd/backend/config.py

Lines changed: 3 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -18,24 +18,12 @@
1818
import torch
1919

2020
from .....configuration_utils import NeuronConfig, register_neuron_config
21+
from .....utils import map_torch_dtype
2122

2223

2324
NEURON_CONFIG_FILE = "neuron_config.json"
2425

2526

26-
def to_torch_dtype(dtype_str: str) -> torch.dtype:
27-
dtype_mapping = {
28-
"float32": torch.float32,
29-
"float16": torch.float16,
30-
"bfloat16": torch.bfloat16,
31-
"fp32": torch.float32,
32-
"fp16": torch.float16,
33-
"bf16": torch.bfloat16,
34-
}
35-
assert dtype_str in dtype_mapping, f"Unsupported dtype: {dtype_str}"
36-
return dtype_mapping[dtype_str]
37-
38-
3927
def to_dict(obj):
4028
if type(obj) is dict:
4129
return {k: to_dict(v) for k, v in obj.items()}
@@ -131,15 +119,15 @@ def __init__(
131119
self.tp_degree = tp_degree
132120
self.torch_dtype = torch_dtype
133121
if isinstance(self.torch_dtype, str):
134-
self.torch_dtype = to_torch_dtype(self.torch_dtype)
122+
self.torch_dtype = map_torch_dtype(self.torch_dtype)
135123
self.n_active_tokens = self.sequence_length if n_active_tokens is None else n_active_tokens
136124
self.output_logits = output_logits
137125

138126
self.padding_side = padding_side
139127

140128
self.rpl_reduce_dtype = torch_dtype if rpl_reduce_dtype is None else rpl_reduce_dtype
141129
if isinstance(self.rpl_reduce_dtype, str):
142-
self.rpl_reduce_dtype = to_torch_dtype(self.rpl_reduce_dtype)
130+
self.rpl_reduce_dtype = map_torch_dtype(self.rpl_reduce_dtype)
143131

144132
# fallback to sequence_length is for compatibility with vllm
145133
self.max_context_length = max_context_length

optimum/neuron/utils/misc.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -631,6 +631,9 @@ def map_torch_dtype(dtype: Union[str, torch.dtype]):
631631
"float64": torch.float64,
632632
"int32": torch.int32,
633633
"int64": torch.int64,
634+
"bf16": torch.bfloat16,
635+
"fp16": torch.float16,
636+
"fp32": torch.float32,
634637
}
635638

636639
if isinstance(dtype, str) and dtype in dtype_mapping:

tests/decoder/test_decoder_export.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,8 @@
1919
from transformers import AutoModelForCausalLM
2020

2121
from optimum.neuron import NeuronModelForCausalLM
22-
from optimum.neuron.models.inference.nxd.backend.config import to_torch_dtype
2322
from optimum.neuron.models.inference.nxd.llama.modeling_llama import LlamaNxDModelForCausalLM
23+
from optimum.neuron.utils import map_torch_dtype
2424
from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx
2525

2626

@@ -53,7 +53,7 @@ def check_neuron_model(neuron_model, batch_size=None, sequence_length=None, num_
5353
if hasattr(neuron_config, "auto_cast_type"):
5454
assert neuron_config.auto_cast_type == auto_cast_type
5555
elif hasattr(neuron_config, "torch_dtype"):
56-
assert neuron_config.torch_dtype == to_torch_dtype(auto_cast_type)
56+
assert neuron_config.torch_dtype == map_torch_dtype(auto_cast_type)
5757

5858

5959
def _test_decoder_export_save_reload(

0 commit comments

Comments
 (0)