Skip to content

Commit d872eb2

Browse files
authored
Fix torch.cuda.set_device (#54)
1 parent e2e9851 commit d872eb2

3 files changed

Lines changed: 4 additions & 3 deletions

File tree

gllm/dist_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def recv_tensor(dtype, src):
2020
tensor_shape = [None for _ in range(dim[0])]
2121
dist.recv_object_list(tensor_shape,src)
2222
# recv tensor
23-
tensor = torch.zeros(torch.Size(tensor_shape),dtype=dtype,device=f'cuda:{dist.get_rank()}')
23+
tensor = torch.zeros(torch.Size(tensor_shape),dtype=dtype,device=f'cuda:{get_local_rank()}')
2424
dist.recv(tensor,src)
2525
return tensor
2626

@@ -33,7 +33,7 @@ def send_pp_data(output, dst):
3333
dist.isend(output,dst)
3434

3535
def recv_pp_data(src, dtype, shape, has_residual):
36-
hidden_states = torch.zeros(torch.Size(shape),dtype=dtype,device=f'cuda:{dist.get_rank()}')
36+
hidden_states = torch.zeros(torch.Size(shape),dtype=dtype,device=f'cuda:{get_local_rank()}')
3737
if has_residual:
3838
residual = hidden_states.clone().detach()
3939
hidden_states_future = dist.irecv(hidden_states,src)

gllm/model_loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,5 +117,6 @@ def load_model(self, mp_load_progress=None):
117117
return model
118118
else:
119119
assert self.load_format == 'dummy'
120+
torch.set_default_device('cuda')
120121
model = model_type(self.config)
121122
return model

gllm/worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def init(self):
4747

4848
init_dist(self.pp_size, self.local_rank, self.pp_rank, self.master_addr,
4949
self.master_port, self.assigned_layers)
50-
torch.cuda.set_device(f'cuda:{self.pp_rank}')
50+
torch.cuda.set_device(f'cuda:{self.local_rank}')
5151

5252
self.comm.init()
5353

0 commit comments

Comments
 (0)