Description
Hi, thanks again for the amazing work!
I'm back with another issue:
I'm still on v0.1.3 rather than 1.0.0, so this may no longer be relevant.
When I load perch_v2 through the perch-hoplite zoo and compute gradients with tf.GradientTape, the call fails because gradients are explicitly disabled inside the exported model (the jax2tf conversion appears to use with_gradient=False).
This blocks gradient-based explainability methods (e.g., SmoothGrad, input gradients), even when the downstream linear classifier head is implemented in-graph.
The gradient call fails with:
LookupError: Gradient explicitly disabled. Reason: b'The jax2tf-converted function does not support gradients. Use with_gradient parameter to enable gradients'
CAM and occlusion work, but true gradient attribution cannot be used.
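For context, this is roughly the kind of attribution I'd like to run once gradients are available. Below is a minimal NumPy sketch of SmoothGrad, which averages input gradients over Gaussian-noised copies of the input. `grad_fn` is a hypothetical callable standing in for the `tf.GradientTape` call that currently raises the LookupError, and scaling the noise by the input's value range is one common convention, assumed here rather than taken from any Perch tooling.

```python
import numpy as np


def smoothgrad(grad_fn, x, n_samples=50, noise_level=0.1, seed=0):
    """Average d(score)/d(input) over noisy copies of x (SmoothGrad).

    grad_fn: hypothetical callable returning the gradient of the class score
             w.r.t. its input, with the same shape as the input.
    noise_level: noise std as a fraction of (max(x) - min(x)).
    """
    rng = np.random.default_rng(seed)
    sigma = noise_level * (np.max(x) - np.min(x))
    total = np.zeros_like(x, dtype=np.float64)
    for _ in range(n_samples):
        noisy = x + rng.normal(0.0, sigma, size=x.shape)
        total += grad_fn(noisy)
    return total / n_samples


if __name__ == "__main__":
    # Toy score sum(x**2) has gradient 2x; SmoothGrad should stay close to 2x.
    x = np.array([1.0, 2.0, 3.0])
    print(smoothgrad(lambda z: 2.0 * z, x, n_samples=2000))
```

With the real model, `grad_fn` would wrap the exported forward pass plus the linear head inside a `tf.GradientTape`, which is exactly the step that is currently blocked.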
Requested fix / enhancement
Provide a gradient-enabled Perch v2 export (or an optional variant) converted with jax2tf ... with_gradient=True.
If this is impossible for whatever reason, that's okay; it would just be nice to have the option.
P.S. I previously mentioned the shape mismatch between the preprint (5, 3, 1536) and the Kaggle page (16, 4, 1536), and you confirmed this was a small error that will be corrected in the preprint before the final submission. Just a note that the Kaggle page itself still contains one mention of the incorrect (5, 3, 1536) shape: under 'Model details -> Model description' it says (5, 3, 1536) (incorrect), while under 'Model variations -> Model description' it says (16, 4, 1536) (correct).