Skip to content

Commit 2693eac

Browse files
committed
4x speed-up of loglikelihood requests
1 parent a9ca563 commit 2693eac

File tree

1 file changed

+5
-15
lines changed

1 file changed

+5
-15
lines changed

examples/text-generation/model_adapter.py

Lines changed: 5 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -170,9 +170,7 @@ def max_length(self) -> int:
170170

171171
@property
172172
def device(self):
173-
# We need to do padding ourselves, otherwise we'll end up with recompilations
174-
# Returning 'cpu' to keep tensors on CPU in lm_eval code
175-
return "cpu"
173+
return "hpu"
176174

177175
@max_length.setter
178176
def max_length(self, value: int) -> None:
@@ -202,12 +200,10 @@ def _model_call(self, inps: torch.Tensor) -> torch.Tensor:
202200
pad_token_id = getattr(self._model.config, "pad_token_id", 0)
203201
inps = F.pad(inps, (0, padding_length), value=pad_token_id)
204202
eval_logger.debug(f"Padded input from {seq_length} to {bucket_length} (pad={padding_length})")
205-
logits = self._model(inps.to(self.device_), **self.model_inputs)["logits"].cpu()
203+
logits = self._model(inps.to(self.device_), **self.model_inputs)["logits"]
206204

207205
if self.options.static_shapes and padding_length > 0:
208206
logits = logits[:, :-padding_length, :]
209-
logits = logits.to(torch.float32)
210-
211207
return logits
212208

213209
def generate_until(self, requests: list[Instance], disable_tqdm: bool = False) -> list[str]:
@@ -232,16 +228,10 @@ def _model_generate(
232228
Patched method
233229
source: https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.9.1/lm_eval/models/huggingface.py#L951
234230
"""
235-
# temperature = 0.0 if not set
236-
# if do_sample is false and temp==0.0:
237-
# remove temperature, as do_sample=False takes care of this
238-
# and we don't want a warning from HF
239231
generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
240232
do_sample = generation_kwargs.get("do_sample")
241-
# The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
242233
if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
243234
generation_kwargs["do_sample"] = do_sample = False
244-
245235
if do_sample is False and generation_kwargs.get("temperature") == 0.0:
246236
generation_kwargs.pop("temperature")
247237
if self.options.static_shapes:
@@ -257,10 +247,10 @@ def _model_generate(
257247
generation_kwargs["attention_mask"], (0, padding_length), value=0
258248
)
259249
# move context & attention_mask to hpu
260-
context = context.to("hpu")
261-
generation_kwargs["attention_mask"] = generation_kwargs["attention_mask"].to("hpu")
250+
context = context.to(self.device)
251+
generation_kwargs["attention_mask"] = generation_kwargs["attention_mask"].to(self.device)
262252
with torch.autocast(
263-
device_type="hpu",
253+
device_type=self.device,
264254
dtype=self.mixed_precision_dtype,
265255
enabled=self.mixed_precision_dtype is not None,
266256
):

0 commit comments

Comments
 (0)