Skip to content

Commit 2f7dc22

Browse files
committed
compatible to 4.57.1
1 parent e96b727 commit 2f7dc22

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

examples/text2text-generation/run_gpt_oss.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import typing
33

44
import fire
5-
from transformers import AutoTokenizer, GptOssForCausalLM
5+
import torch
6+
from transformers import AutoConfig, AutoTokenizer, GptOssForCausalLM
67

78
from optimum.rbln import RBLNGptOssForCausalLM
89

@@ -49,6 +50,12 @@ def main(
4950
diff: bool = False,
5051
n_layers: int = 2,
5152
):
53+
target_config = AutoConfig.from_pretrained(model_id)
54+
target_config._attn_implementation = "eager"
55+
target_config.num_hidden_layers = n_layers
56+
target_config.layer_types = target_config.layer_types[:n_layers]
57+
# target_config.dtype = torch.float32
58+
5259
if from_transformers:
5360
model = RBLNGptOssForCausalLM.from_pretrained(
5461
model_id,
@@ -57,8 +64,8 @@ def main(
5764
rbln_max_seq_len=max_seq_len,
5865
rbln_tensor_parallel_size=tensor_parallel_size,
5966
rbln_kvcache_partition_len=kvcache_partition_len,
60-
num_hidden_layers=n_layers,
61-
dtype="float32",
67+
config=target_config,
68+
dtype=torch.float32,
6269
)
6370
model.save_pretrained(os.path.basename(model_id))
6471
else:
@@ -95,11 +102,7 @@ def main(
95102
logits = rbln_outputs.logits
96103

97104
if diff:
98-
golden_model = GptOssForCausalLM.from_pretrained(
99-
model_id,
100-
num_hidden_layers=n_layers,
101-
_attn_implementation="eager",
102-
)
105+
golden_model = GptOssForCausalLM.from_pretrained(model_id, config=target_config, dtype=torch.float32)
103106
golden_outputs = golden_model.generate(
104107
**inputs,
105108
do_sample=False,

0 commit comments

Comments
 (0)