Skip to content

Commit 5917ae3

Browse files
authored
Refactored save_preview method for cleaner output handling in segmentation presets. (#215) (#216)
1 parent 13c3d0e commit 5917ae3

1 file changed

Lines changed: 1 addition & 3 deletions

File tree

mipcandy/presets/segmentation.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ def apply_non_linearity(self, x: torch.Tensor, channel_dim: int) -> torch.Tensor
7373
@override
7474
def save_preview(self, image: torch.Tensor, label: torch.Tensor, output: torch.Tensor, *,
7575
quality: float = .75) -> None:
76-
output = self.apply_non_linearity(output, 0)
77-
if output.shape[0] != 1:
78-
output = convert_logits_to_ids(output, channel_dim=0).int()
76+
output = convert_logits_to_ids(self.apply_non_linearity(output, 0), channel_dim=0)
7977
self._save_preview(image, "input", quality)
8078
self._save_preview(label.int(), "label", quality, is_label=True)
8179
self._save_preview(output, "prediction", quality, is_label=True)

0 commit comments

Comments
 (0)