For the segmentation approach, make the following changes as refactors:
- Reducers should become full custom autograd functions with extensible element-wise functions f(x) and g(x), applied just before and just after aggregation (sum/mean), respectively
- seen_masks currently constructs a full (H, W) map of seen indices. Check whether this can be done tile-wise instead
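
As a rough sketch of the reducer item, the forward/backward pair for y = g(agg(f(x_i))) can be written without any framework. It only illustrates the chain rule a custom autograd function would have to implement. The names `Elementwise`, `Reducer` and `IDENTITY` are hypothetical and do not come from the existing code, and the actual version would live in whatever autograd framework the project uses.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Sequence

Fn = Callable[[float], float]


@dataclass(frozen=True)
class Elementwise:
    """A scalar function paired with its derivative (hypothetical helper)."""
    fn: Fn
    dfn: Fn


IDENTITY = Elementwise(lambda v: v, lambda v: 1.0)


class Reducer:
    """Computes y = g(agg(f(x_i))) where agg is "sum" or "mean"."""

    def __init__(self, f: Elementwise = IDENTITY, g: Elementwise = IDENTITY,
                 mode: str = "sum") -> None:
        if mode not in ("sum", "mean"):
            raise ValueError("mode must be 'sum' or 'mean'")
        self.f, self.g, self.mode = f, g, mode
        self._saved = None  # values kept from forward() for backward()

    def forward(self, xs: Sequence[float]) -> float:
        if not xs:
            raise ValueError("cannot reduce an empty sequence")
        # mean is a sum scaled by 1/N, so both modes share one code path
        scale = 1.0 / len(xs) if self.mode == "mean" else 1.0
        s = scale * sum(self.f.fn(x) for x in xs)
        self._saved = (list(xs), s, scale)
        return self.g.fn(s)

    def backward(self, grad_out: float = 1.0) -> List[float]:
        if self._saved is None:
            raise RuntimeError("call forward() before backward()")
        xs, s, scale = self._saved
        # dy/dx_i = grad_out * g'(s) * scale * f'(x_i)
        outer = grad_out * self.g.dfn(s) * scale
        return [outer * self.f.dfn(x) for x in xs]


# Example: log-mean-exp, i.e. f = exp, g = log, mean aggregation
log_mean_exp = Reducer(
    f=Elementwise(math.exp, math.exp),
    g=Elementwise(math.log, lambda v: 1.0 / v),
    mode="mean",
)
value = log_mean_exp.forward([0.0, 1.0, 2.0])
grads = log_mean_exp.backward()  # equals softmax of the inputs
```

Passing f and g as (function, derivative) pairs keeps the reducer extensible: a new pre- or post-aggregation function only needs its own derivative, and the sum/mean logic stays untouched.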
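
For the seen_masks item, one possible direction is to allocate storage per tile, and only for tiles that contain at least one seen index. The sketch below assumes "seen" means a pixel index (y, x) has been visited. `TiledSeenMask`, its methods and the tile size of 64 are hypothetical choices for illustration, not taken from the existing code.

```python
from typing import Dict, Tuple


class TiledSeenMask:
    """Sparse seen-mask: one small buffer per touched tile, not a full (H, W) map."""

    def __init__(self, height: int, width: int, tile: int = 64) -> None:
        if height <= 0 or width <= 0 or tile <= 0:
            raise ValueError("height, width and tile must be positive")
        self.height, self.width, self.tile = height, width, tile
        # (tile_row, tile_col) -> flat tile*tile buffer, created on first write
        self._tiles: Dict[Tuple[int, int], bytearray] = {}

    def _check(self, y: int, x: int) -> None:
        if not (0 <= y < self.height and 0 <= x < self.width):
            raise IndexError(f"({y}, {x}) is outside the {self.height}x{self.width} map")

    def mark(self, y: int, x: int) -> None:
        self._check(y, x)
        key = (y // self.tile, x // self.tile)
        buf = self._tiles.get(key)
        if buf is None:
            buf = self._tiles[key] = bytearray(self.tile * self.tile)
        buf[(y % self.tile) * self.tile + (x % self.tile)] = 1

    def mark_patch(self, y0: int, x0: int, h: int, w: int) -> None:
        """Mark an h-by-w patch starting at (y0, x0), clipped to the map bounds."""
        for y in range(max(y0, 0), min(y0 + h, self.height)):
            for x in range(max(x0, 0), min(x0 + w, self.width)):
                self.mark(y, x)

    def is_seen(self, y: int, x: int) -> bool:
        self._check(y, x)
        buf = self._tiles.get((y // self.tile, x // self.tile))
        if buf is None:
            return False  # tile never touched, so nothing in it is seen
        return buf[(y % self.tile) * self.tile + (x % self.tile)] == 1

    def seen_count(self) -> int:
        return sum(sum(buf) for buf in self._tiles.values())

    @property
    def allocated_tiles(self) -> int:
        return len(self._tiles)


mask = TiledSeenMask(1000, 1000, tile=64)
mask.mark_patch(60, 60, 10, 10)  # straddles a tile corner, so 4 tiles are allocated
```

Memory then scales with the number of touched tiles rather than with H*W. Whether that is a net win depends on how sparse the seen indices are in practice, which this note does not establish.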