Skip to content

Commit ab7a238

Browse files
authored
Use sync recv and modify sampler (#109)
* Modify sampler * Use sync send recv * revert benchmark serving
1 parent d20aa2d commit ab7a238

3 files changed

Lines changed: 13 additions & 19 deletions

File tree

gllm/dist_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ def recv_pp_data(src, shape, has_residual):
1616
hidden_states = torch.zeros(torch.Size(shape))
1717
if has_residual:
1818
residual = hidden_states.clone().detach()
19-
hidden_states_future = dist.irecv(hidden_states, src)
20-
residual_future = dist.irecv(residual, src)
21-
return hidden_states_future, residual_future, hidden_states, residual
19+
dist.recv(hidden_states, src)
20+
dist.recv(residual, src)
21+
return hidden_states, residual
2222
else:
23-
hidden_states_future = dist.irecv(hidden_states, src)
24-
return hidden_states_future, hidden_states
23+
dist.recv(hidden_states, src)
24+
return hidden_states
2525

2626
def send_obj_list(obj_list, dst):
2727
dist.send_object_list(obj_list, dst=dst)

gllm/layers/sampler.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@ def forward(self, logits: torch.Tensor, input_data: InputData):
1313
# top_p top_k
1414
logits = self._apply_top_k_top_p(logits, input_data.top_p, input_data.top_k)
1515
probs = torch.softmax(logits, dim=1)
16-
# q = torch.empty_like(probs)
17-
# q.exponential_()
18-
# return probs.div_(q).argmax(dim=1).cpu().numpy().tolist()
19-
return torch.multinomial(probs, 1).squeeze(1).cpu().numpy().tolist()
16+
17+
q = torch.empty_like(probs)
18+
q.exponential_()
19+
return probs.div_(q).argmax(dim=1).cpu().numpy().tolist()
20+
# return torch.multinomial(probs, 1).squeeze(1).cpu().numpy().tolist()
2021

2122
def _apply_top_k_top_p(
2223
self,

gllm/worker.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -108,17 +108,10 @@ def forward_pp(self):
108108
hidden_states = None
109109
residual = None
110110
if self.ret_residual:
111-
input_data, (hidden_states_future, residual_future,
112-
hidden_states, residual) = self.run_queue[0]
113-
if not (hidden_states_future.is_completed() and residual_future.is_completed()):
114-
return
111+
input_data, (hidden_states, residual) = self.run_queue.popleft()
115112
else:
116-
input_data, (hidden_states_future,
117-
hidden_states) = self.run_queue[0]
118-
if not hidden_states_future.is_completed():
119-
return
120-
121-
self.run_queue.popleft()
113+
input_data, hidden_states = self.run_queue.popleft()
114+
122115
output = self.model_runner.step_once(
123116
input_data, hidden_states, residual)
124117
if is_output_rank():

0 commit comments

Comments
 (0)