Skip to content

Commit 0becbdb

Browse files
authored
Replace one_hot with scatter_ in convert_ids_to_logits() speedup (#211)
* Replace `one_hot` with `scatter_` in `convert_ids_to_logits()` for 3-8x speedup (#204) * Use tuple unpacking for shape construction
1 parent 5917ae3 commit 0becbdb

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

mipcandy/data/convertion.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Literal
22

33
import torch
4-
from torch import nn
54

65
from mipcandy.common import Normalize
76

@@ -15,7 +14,9 @@ def convert_ids_to_logits(ids: torch.Tensor, d: Literal[1, 2, 3], num_classes: i
1514
ids = ids.squeeze(1)
1615
else:
1716
raise ValueError(f"`ids` should be {d} dimensional or {d + 1} dimensional with single channel")
18-
return nn.functional.one_hot(ids.long(), num_classes).movedim(-1, 1).contiguous().float()
17+
logits = torch.zeros((ids.shape[0], num_classes, *ids.shape[1:]), device=ids.device, dtype=torch.float32)
18+
logits.scatter_(1, ids.unsqueeze(1).long(), 1)
19+
return logits
1920

2021

2122
def convert_logits_to_ids(logits: torch.Tensor, *, channel_dim: int = 1) -> torch.Tensor:

0 commit comments

Comments
 (0)