
Scatter is potentially faster than one-hot #204

@ATATC

Description

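import torch
# Write 1 at each class index along dim 1; output_seg must be int64 with the same ndim as onehot.
onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float16)
onehot.scatter_(1, output_seg, 1)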

The snippet above is likely to be faster than the current one-hot implementation in convert_ids_to_logits(), since it writes the ones in place into a tensor that already has the target shape and dtype.
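A minimal sketch for checking that the scatter_ approach produces the same tensor as an F.one_hot-based encoding. It assumes class-index maps of shape (N, H, W) and a channels-first (N, C, H, W) output; the function names `onehot_via_scatter` and `onehot_via_one_hot` are hypothetical and this is not the repository's convert_ids_to_logits().

```python
import torch
import torch.nn.functional as F


def onehot_via_scatter(ids: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: one-hot encode ids (N, H, W) into (N, C, H, W) with scatter_."""
    n, h, w = ids.shape
    out = torch.zeros((n, num_classes, h, w), device=ids.device, dtype=torch.float16)
    # scatter_ needs an int64 index with the same ndim as `out`, so add the class dim.
    out.scatter_(1, ids.unsqueeze(1).long(), 1)
    return out


def onehot_via_one_hot(ids: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical reference: F.one_hot puts classes last, so permute them to dim 1."""
    return F.one_hot(ids.long(), num_classes).permute(0, 3, 1, 2).to(torch.float16)


ids = torch.randint(0, 5, (2, 4, 4))
a = onehot_via_scatter(ids, 5)
b = onehot_via_one_hot(ids, 5)
print(torch.equal(a, b))  # True: both encodings hold identical values
```

To check the speed claim itself, both functions could be timed on the target GPU, for example with torch.utils.benchmark.Timer.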

Metadata


Labels

code review: Code review or comment
