@@ -2,7 +2,8 @@
 import typing
 
 import fire
-from transformers import AutoTokenizer, GptOssForCausalLM
+import torch
+from transformers import AutoConfig, AutoTokenizer, GptOssForCausalLM
 
 from optimum.rbln import RBLNGptOssForCausalLM
 
@@ -49,6 +50,12 @@ def main(
     diff: bool = False,
     n_layers: int = 2,
 ):
+    target_config = AutoConfig.from_pretrained(model_id)
+    target_config._attn_implementation = "eager"
+    target_config.num_hidden_layers = n_layers
+    target_config.layer_types = target_config.layer_types[:n_layers]
+    # target_config.dtype = torch.float32
+
     if from_transformers:
         model = RBLNGptOssForCausalLM.from_pretrained(
             model_id,
@@ -57,8 +64,8 @@ def main(
             rbln_max_seq_len=max_seq_len,
             rbln_tensor_parallel_size=tensor_parallel_size,
             rbln_kvcache_partition_len=kvcache_partition_len,
-            num_hidden_layers=n_layers,
-            dtype="float32",
+            config=target_config,
+            dtype=torch.float32,
         )
         model.save_pretrained(os.path.basename(model_id))
     else:
@@ -95,11 +102,7 @@ def main(
     logits = rbln_outputs.logits

     if diff:
-        golden_model = GptOssForCausalLM.from_pretrained(
-            model_id,
-            num_hidden_layers=n_layers,
-            _attn_implementation="eager",
-        )
+        golden_model = GptOssForCausalLM.from_pretrained(model_id, config=target_config, dtype=torch.float32)
         golden_outputs = golden_model.generate(
             **inputs,
             do_sample=False,
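
For reference, below is a minimal standalone sketch of the golden-comparison pattern this commit switches to: build one truncated config, load the float32 transformers reference from it, and diff its logits against the RBLN output. The model id, prompt, and the final comparison line are illustrative assumptions, not values taken from the commit.

# Minimal sketch; model_id, the prompt, and n_layers here are placeholders.
import torch
from transformers import AutoConfig, AutoTokenizer, GptOssForCausalLM

model_id = "openai/gpt-oss-20b"  # assumed checkpoint; substitute the real model_id
n_layers = 2

# Truncate the config so only the first n_layers are instantiated and compared.
target_config = AutoConfig.from_pretrained(model_id)
target_config._attn_implementation = "eager"
target_config.num_hidden_layers = n_layers
target_config.layer_types = target_config.layer_types[:n_layers]

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello, my name is", return_tensors="pt")

# Float32 golden reference built from the same truncated config as the RBLN model.
golden_model = GptOssForCausalLM.from_pretrained(model_id, config=target_config, dtype=torch.float32)
golden_model.eval()

with torch.no_grad():
    golden_logits = golden_model(**inputs).logits

# Compare against the logits produced by the RBLN-compiled model
# (rbln_outputs.logits in the script above), e.g.:
# print((golden_logits - rbln_outputs.logits).abs().max())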