@@ -15,22 +15,26 @@ def __init__(self, ssm_model_name: str, num_workers: int = 2, device: str = 'cud
1515 from transformers import AutoModelForCausalLM
1616
1717 self .num_workers = num_workers
18- self .device = torch .device (device )
1918
2019 self .ssms = []
2120 self .streams = []
22- for _ in range (num_workers ):
21+ self .devices = []
22+
23+ for i in range (num_workers ):
24+        device_i = torch.device(f'cuda:{i}')
25+ self .devices .append (device_i )
26+
2327 ssm = AutoModelForCausalLM .from_pretrained (
2428 ssm_model_name ,
2529 torch_dtype = torch .float16 )
26- ssm = ssm .to (self . device )
30+ ssm = ssm .to (device_i )
2731 ssm .eval ()
2832 self .ssms .append (ssm )
29- self .streams .append (torch .cuda .Stream (device = self . device ))
33+ self .streams .append (torch .cuda .Stream (device = device_i ))
3034
3135 with torch .no_grad ():
32- dummy = torch . ones ( 1 , 8 , dtype = torch . long , device = self .device )
33- for ssm in self .ssms :
36+ for i , ssm in enumerate ( self .ssms ):
37+ dummy = torch . ones ( 1 , 8 , dtype = torch . long , device = self .devices [ i ])
3438 ssm (dummy , attention_mask = torch .ones_like (dummy ))
3539
3640 def build_trees_parallel (
@@ -49,10 +53,11 @@ def build_trees_parallel(
4953 def worker_fn (worker_idx : int , batch_indices : List [int ]):
5054 ssm = self .ssms [worker_idx ]
5155 stream = self .streams [worker_idx ]
56+ device = self .devices [worker_idx ]
5257
5358 with torch .cuda .stream (stream ):
5459 results = self ._build_trees_batched (
55- batch_indices , input_ids , seq_lengths , ssm , beam_width , max_depth
60+ batch_indices , input_ids , seq_lengths , ssm , beam_width , max_depth , device
5661 )
5762 for batch_idx , tree in results :
5863 all_results [batch_idx ] = tree
@@ -72,8 +77,9 @@ def worker_fn(worker_idx: int, batch_indices: List[int]):
7277 t .join ()
7378
7479 # Synchronize all CUDA streams
75- for stream in self .streams :
76- stream .synchronize ()
80+ for i , stream in enumerate (self .streams ):
81+ with torch .cuda .device (self .devices [i ]):
82+ stream .synchronize ()
7783
7884 return all_results
7985
@@ -85,9 +91,10 @@ def _build_trees_batched(
8591 ssm ,
8692 beam_width : int ,
8793 max_depth : int ,
94+ device : torch .device ,
8895 ) -> List :
8996
90- pad_token_id = getattr (ssm .config , 'pad_token_id' , 0 )
97+ pad_token_id = getattr (ssm .config , 'pad_token_id' , None ) or 0
9198
9299 trees = {}
93100 valid_inputs = {}
@@ -96,7 +103,7 @@ def _build_trees_batched(
96103
97104 for batch_idx in batch_indices :
98105 actual_len = seq_lengths [batch_idx ].item ()
99- valid_input_ids = input_ids [batch_idx , :actual_len ]
106+ valid_input_ids = input_ids [batch_idx , :actual_len ]. to ( device )
100107 valid_inputs [batch_idx ] = valid_input_ids
101108 prefix_lengths [batch_idx ] = max (actual_len - 1 , 0 )
102109
@@ -118,23 +125,23 @@ def _build_trees_batched(
118125 if pf_len > 0 :
119126 prefix = valid_inputs [batch_idx ][:- 1 ]
120127 else :
121- prefix = torch .tensor ([], dtype = torch .long , device = self . device )
128+ prefix = torch .tensor ([], dtype = torch .long , device = device )
122129
123130 pad_len = max_prefix_len - pf_len
124131
125132 if pf_len > 0 :
126133 padded_prefixes .append (torch .cat ([
127- torch .full ((pad_len ,), pad_token_id , dtype = torch .long , device = self . device ),
134+ torch .full ((pad_len ,), pad_token_id , dtype = torch .long , device = device ),
128135 prefix
129136 ]))
130137 else :
131138 padded_prefixes .append (
132- torch .full ((max_prefix_len ,), pad_token_id , dtype = torch .long , device = self . device )
139+ torch .full ((max_prefix_len ,), pad_token_id , dtype = torch .long , device = device )
133140 )
134141
135142 prefix_masks .append (torch .cat ([
136- torch .zeros (pad_len , dtype = torch .long , device = self . device ),
137- torch .ones (pf_len , dtype = torch .long , device = self . device )
143+ torch .zeros (pad_len , dtype = torch .long , device = device ),
144+ torch .ones (pf_len , dtype = torch .long , device = device )
138145 ]))
139146
140147 batch_prefixes = torch .stack (padded_prefixes )
@@ -166,7 +173,7 @@ def _build_trees_batched(
166173
167174 for node in tree .get_nodes_at_depth (depth ):
168175 path = node .get_path_from_root ()
169- path_tokens = torch .tensor ([root_token ] + path , dtype = torch .long , device = self . device )
176+ path_tokens = torch .tensor ([root_token ] + path , dtype = torch .long , device = device )
170177 all_paths .append (path_tokens )
171178 node_mapping .append ((batch_idx , node ))
172179 cache_indices .append (idx_map [batch_idx ])
@@ -180,8 +187,8 @@ def _build_trees_batched(
180187 total_mask_len = max_pf_len + max_path_len
181188
182189 # Preallocate output tensors
183- batch_paths = torch .full ((num_nodes , max_path_len ), pad_token_id , dtype = torch .long , device = self . device )
184- batch_path_masks = torch .zeros ((num_nodes , total_mask_len ), dtype = torch .long , device = self . device )
190+ batch_paths = torch .full ((num_nodes , max_path_len ), pad_token_id , dtype = torch .long , device = device )
191+ batch_path_masks = torch .zeros ((num_nodes , total_mask_len ), dtype = torch .long , device = device )
185192
186193 # Fill in path tokens and attention masks
187194 for i , path in enumerate (all_paths ):
@@ -212,7 +219,7 @@ def _build_trees_batched(
212219 all_logits = outputs .logits [:, - 1 , :]
213220
214221 t_forward += time .perf_counter () - t0
215- t0 = time .perf_counter ()
222+ t0 = time .perf_counter ()
216223 # 批量 topk
217224 _ , all_top_k_indices = torch .topk (all_logits , k = beam_width , dim = - 1 )
218225 all_probs = torch .softmax (all_logits , dim = - 1 )