Linear probe exposure#6
Conversation
Fix classification labels for batch size 1
Sllambias
left a comment
There was a problem hiding this comment.
Hi Luka,
Appreciate the PR's. I agree with the problems but I have two comments about this. One of them involves some changes that I can also help with contributing if you'd like. The other comment is just a slight uncertainty that I haven't had the chance to test myself yet.
| self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) | ||
|
|
||
| feature_dim = self.model.decoder.fc.in_features | ||
| try: |
There was a problem hiding this comment.
I see the problem here. I think the solution should support arbitrary networks and for that I think the solution is to (1) in the model configs add a linear_probe_feature_dim_layer_name: "head.in_features" arg (2) add an "feature_dim_layer=cfg.model.linear_probe_feature_dim_layer_name" arg to the linear_probe_module and (3) figure out how to parse that for networks of arbitrary depths as it will be a string so that we in essence get feature_dim=self.model.get(feature_dim_layer).get(in_features)
There was a problem hiding this comment.
I like your approach, it's general
My "fix" is more of a quickfix
There was a problem hiding this comment.
Okay super. Do you want to implement it or should I/we?
On our end I think it will be possible later this week
There was a problem hiding this comment.
I won't have bandwidth to tackle this until mid next week :/ I'd appreciate you picking it up - ofc if you have time - otherwise we live in the quickfix world until someone gets time to do it :)
There was a problem hiding this comment.
@lukasugar I moved it to #10 so I could commit changes. If it's relevant for you to be credited with the merged PR we can migrate the final changes back to this branch before merging once the solution is confirmed to work :)
Adding
asp_linear_probeas a script inpyproject.toml