Skip to content

Commit 8a929b2

Browse files
committed
Unified with upstream
1 parent 2693eac commit 8a929b2

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/text-generation/model_adapter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def max_length(self) -> int:
170170

171171
@property
172172
def device(self):
173-
return "hpu"
173+
return torch.device("hpu")
174174

175175
@max_length.setter
176176
def max_length(self, value: int) -> None:
@@ -200,7 +200,7 @@ def _model_call(self, inps: torch.Tensor) -> torch.Tensor:
200200
pad_token_id = getattr(self._model.config, "pad_token_id", 0)
201201
inps = F.pad(inps, (0, padding_length), value=pad_token_id)
202202
eval_logger.debug(f"Padded input from {seq_length} to {bucket_length} (pad={padding_length})")
203-
logits = self._model(inps.to(self.device_), **self.model_inputs)["logits"]
203+
logits = self._model(inps.to(self.device), **self.model_inputs)["logits"]
204204

205205
if self.options.static_shapes and padding_length > 0:
206206
logits = logits[:, :-padding_length, :]

0 commit comments

Comments
 (0)