Skip to content

Commit 9770afa

Browse files
committed
fix review comments
Signed-off-by: Matrix YAO <matrix.yao@intel.com>
1 parent 83d2a9d commit 9770afa

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

src/accelerate/big_modeling.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,10 @@ def init_on_device(device: torch.device, include_buffers: Optional[bool] = None)
108108
109109
```python
110110
import torch.nn as nn
111-
from accelerate import Accelerator, init_on_device
111+
from accelerate import init_on_device
112112
113-
accelerator = Accelerator()
114-
115-
with init_on_device(device=torch.device(accelerator.device)):
113+
# init model on specified device(e.g., "cuda", "xpu" and so on)
114+
with init_on_device(device=torch.device("cuda")):
116115
tst = nn.Linear(100, 100) # on specified device
117116
```
118117
"""

0 commit comments

Comments
 (0)