
Output class probabilities for segmentation tasks #338

Open
@geoffreyjdawson

Description

Currently, when performing inference on a segmentation task, we can only output the most probable class. It would be useful to have an option to also output the probability of each class as a separate band, i.e. without performing the argmax:

y_hat = y_hat.argmax(dim=1)
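A minimal sketch of such an option, using NumPy in place of PyTorch so it runs standalone. The `postprocess_logits` function and its `return_probabilities` flag are hypothetical, not existing terratorch API. It assumes `y_hat` holds raw logits shaped (batch, classes, height, width), so a softmax over axis 1 turns them into per-class probabilities; in PyTorch the equivalent call would be `torch.softmax(y_hat, dim=1)`.

```python
import numpy as np


def postprocess_logits(y_hat: np.ndarray, return_probabilities: bool = False) -> np.ndarray:
    """Post-process segmentation logits shaped (batch, classes, H, W).

    Hypothetical helper: with return_probabilities=False it mimics the current
    behaviour (argmax over the class axis); with True it keeps one band per class.
    """
    if not return_probabilities:
        # Current behaviour: most probable class index per pixel -> (batch, H, W)
        return y_hat.argmax(axis=1)
    # Numerically stable softmax over the class axis -> (batch, classes, H, W)
    shifted = y_hat - y_hat.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


logits = np.random.default_rng(0).normal(size=(1, 3, 4, 4)).astype(np.float32)
labels = postprocess_logits(logits)
probs = postprocess_logits(logits, return_probabilities=True)
print(labels.shape, probs.shape)  # (1, 4, 4) (1, 3, 4, 4)
```

Because softmax preserves ordering, taking the argmax of the probability output gives back the same class map, so the single-band result can still be derived from the multi-band one if both are wanted.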

This would also require an option for multi-band output here:

https://github.com/IBM/terratorch/blob/30dfdf15716edcc9295559aec4327dee62b0d956/terratorch/cli_tools.py#L87C1-L103C1
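For multi-band output, the writer's band count has to follow the number of classes rather than being fixed at one. The sketch below assumes the prediction is saved as a GeoTIFF with rasterio and a rasterio-style profile dict; that is an assumption about the linked code, not a description of it. `to_multiband` and `save_multiband` are hypothetical names. The array is laid out as (bands, height, width), which is what rasterio's `write()` expects when no band indexes are passed.

```python
from __future__ import annotations

import numpy as np


def to_multiband(prediction: np.ndarray, profile: dict) -> tuple[np.ndarray, dict]:
    """Prepare a (H, W) class map or a (C, H, W) probability stack for writing.

    Hypothetical helper: returns the data as (bands, H, W) plus a copy of the
    profile with 'count', 'dtype', 'height' and 'width' updated to match.
    """
    if prediction.ndim == 2:
        # Single-band class map: add a leading band axis
        data = prediction[np.newaxis, ...]
    elif prediction.ndim == 3:
        # Already one band per class probability
        data = prediction
    else:
        raise ValueError(f"expected a 2D or 3D array, got shape {prediction.shape}")
    new_profile = dict(profile)  # copy so the caller's profile is not mutated
    new_profile.update(
        count=data.shape[0],
        dtype=str(data.dtype),
        height=data.shape[1],
        width=data.shape[2],
    )
    return data, new_profile


def save_multiband(path: str, prediction: np.ndarray, profile: dict) -> None:
    """Write all bands at once; assumes rasterio is installed."""
    import rasterio  # imported lazily so to_multiband works without rasterio

    data, new_profile = to_multiband(prediction, profile)
    with rasterio.open(path, "w", **new_profile) as dst:
        dst.write(data)  # a 3D array writes every band in order
```

Handling both the 2D and 3D cases in one place keeps the existing single-band path working unchanged while the probability path simply carries more bands.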

I have tried this and it works on a single tile, but not with tiled inference.
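Where exactly tiled inference breaks is not identified here, but one requirement is that the stitching step carries the class axis through instead of assuming a single-band result. Below is a generic sketch, not terratorch's implementation: each tile's probabilities are added into a (classes, H, W) accumulator alongside a per-pixel count, and overlapping regions are averaged at the end. `stitch_probabilities` and the (row, col) top-left tile offsets are hypothetical.

```python
import numpy as np


def stitch_probabilities(tiles, height: int, width: int, num_classes: int) -> np.ndarray:
    """Average overlapping per-class probability tiles into one (C, H, W) map.

    `tiles` is an iterable of (row, col, probs), where probs has shape
    (num_classes, tile_h, tile_w) and (row, col) is the tile's top-left corner.
    """
    acc = np.zeros((num_classes, height, width), dtype=np.float64)
    counts = np.zeros((height, width), dtype=np.float64)
    for row, col, probs in tiles:
        _, th, tw = probs.shape
        # Sum probabilities for every class over the tile's footprint
        acc[:, row:row + th, col:col + tw] += probs
        counts[row:row + th, col:col + tw] += 1
    # Divide by the number of tiles covering each pixel; max(..., 1) avoids
    # division by zero for pixels that no tile covered
    return acc / np.maximum(counts, 1)[np.newaxis, ...]
```

Taking `argmax(axis=0)` of the stitched map then yields a class map. Note that averaging probabilities before the argmax can give different labels in overlap regions than stitching per-tile argmax labels, so the two outputs may not match exactly there.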
