Perch v2 has gradients disabled (PreventGradient), blocking SmoothGrad / gradient-based attribution #90

@MasBioCoding

Description

Hi, thanks again for the amazing work!
I'm back with another issue:

I'm still on v0.1.3 rather than 1.0.0, so this may no longer be relevant.

When loading perch_v2 via the perch-hoplite model zoo and computing gradients with tf.GradientTape, I find that gradients are explicitly disabled inside the exported model (the jax2tf conversion appears to have used with_gradient=False).

This blocks gradient-based explainability methods (e.g., SmoothGrad, input gradients), even when the downstream linear classifier head is implemented in-graph.

The gradient call fails with:
LookupError: Gradient explicitly disabled. Reason: b'The jax2tf-converted function does not support gradients. Use with_gradient parameter to enable gradients'

CAM and occlusion-based methods work, but true gradient-based attribution is not possible.
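For context, SmoothGrad itself needs nothing more than an input-gradient function, which is exactly what the current export cannot provide. Below is a framework-agnostic NumPy sketch; grad_fn is a hypothetical callable (for example a tf.GradientTape wrapper around a gradient-enabled model), and the default n_samples and noise_level values are illustrative choices, not tuned for Perch v2.

```python
import numpy as np


def smoothgrad(grad_fn, x, n_samples=50, noise_level=0.15, seed=0):
    """SmoothGrad: average input gradients over noisy copies of x.

    grad_fn: callable mapping an input array to d(score)/d(input), same shape.
    noise_level: Gaussian noise std as a fraction of the input's value range.
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(x, dtype=np.float64)
    sigma = noise_level * (x.max() - x.min())
    total = np.zeros_like(x)
    for _ in range(n_samples):
        noisy = x + rng.normal(0.0, sigma, size=x.shape)
        total += grad_fn(noisy)
    return total / n_samples


# Usage with a linear score s(x) = w . x: its input gradient is w everywhere,
# so the SmoothGrad map should equal w regardless of the noise.
w = np.array([0.5, -1.0, 2.0])
x0 = np.array([1.0, 2.0, 3.0])
print(smoothgrad(lambda z: w, x0))
```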

Requested fix / enhancement
Provide a gradient-enabled Perch v2 export (or an optional variant) converted with jax2tf ... with_gradient=True.

If this isn't possible for whatever reason, that's okay; it would just be nice to have the option.

P.S. I previously mentioned the shape discrepancy between the preprint (5, 3, 1536) and the Kaggle page (16, 4, 1536), and you confirmed this was a small error and that the preprint will be corrected before the final submission. I just wanted to note that the Kaggle page itself still contains one mention of the incorrect (5, 3, 1536) shape: under 'Model details -> Model description' it says (5, 3, 1536) (incorrect), while under 'Model variations -> Model description' it says (16, 4, 1536) (correct).

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests