```python
import torch

# One-hot encode along dim 1: write a 1 at every class index given by output_seg
onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float16)
onehot.scatter_(1, output_seg, 1)
```
The implementation above is likely faster than the current implementation of `convert_ids_to_logits()`, since `scatter_` fills the whole one-hot tensor with a single in-place, vectorized operation.
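For reference, here is a minimal self-contained sketch of the same `scatter_` approach. The helper name `scatter_onehot` is hypothetical, and it assumes (as the `scatter_(1, ...)` call above implies) that the class dimension is dim 1 and that the ids tensor has shape `(N, 1, *)`. It does not reproduce `convert_ids_to_logits()` and makes no timing claim. It only checks the result against `torch.nn.functional.one_hot`.

```python
import torch
import torch.nn.functional as F


def scatter_onehot(ids: torch.Tensor, num_classes: int,
                   dtype: torch.dtype = torch.float16) -> torch.Tensor:
    """Hypothetical helper: one-hot encode `ids` along dim 1 via scatter_.

    `ids` is an int64 tensor of shape (N, 1, *); the result has shape
    (N, num_classes, *) with a 1 at each id position and 0 elsewhere.
    """
    shape = list(ids.shape)
    shape[1] = num_classes
    onehot = torch.zeros(shape, device=ids.device, dtype=dtype)
    # For each position k: onehot[n][ids[n][0][k]][k] = 1
    onehot.scatter_(1, ids, 1)
    return onehot


# Usage: 2 samples, 4 positions each, 5 classes
ids = torch.tensor([[0, 3, 1, 4], [2, 2, 0, 1]]).unsqueeze(1)  # shape (2, 1, 4)
onehot = scatter_onehot(ids, num_classes=5)
print(onehot.shape)  # torch.Size([2, 5, 4])

# Cross-check: F.one_hot puts classes last, so move them to dim 1
reference = F.one_hot(ids.squeeze(1), num_classes=5).permute(0, 2, 1).to(onehot.dtype)
print(torch.equal(onehot, reference))  # True
```

Allocating the zeros tensor once and scattering into it in place avoids building an intermediate `int64` one-hot tensor and then permuting and casting it, as the `F.one_hot` cross-check does.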