Skip to content

Commit 867bf26

Browse files
faaanyxvyv99
authored andcommitted
[doc] fix bugs in how_to_hack_models.md (huggingface#38198)
fix several bugs
1 parent de0bba0 commit 867bf26

File tree

1 file changed

+7
-11
lines changed

1 file changed

+7
-11
lines changed

docs/source/en/how_to_hack_models.md

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,6 @@ class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
9090

9191
attn_weights = (query * self.scale) @ key.transpose(-2, -1)
9292

93-
if self.use_rel_pos:
94-
attn_weights = self.add_decomposed_rel_pos(
95-
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
96-
)
97-
9893
attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
9994
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
10095
attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
@@ -114,13 +109,14 @@ Load the model with [`~PreTrainedModel.from_pretrained`].
114109

115110
```py
116111
from transformers import SamModel
117-
from transformers.models.sam import modeling_sam
118-
119-
# replace the attention class in the modeling_sam module
120-
modeling_sam.SamVisionAttention = SamVisionAttentionSplit
121112

122113
# load the pretrained SAM model
123114
model = SamModel.from_pretrained("facebook/sam-vit-base")
115+
116+
# replace the attention class in the vision_encoder module
117+
for layer in model.vision_encoder.layers:
118+
if hasattr(layer, "attn"):
119+
layer.attn = SamVisionAttentionSplit(model.config.vision_config, model.config.vision_config.window_size)
124120
```
125121

126122
## LoRA
@@ -138,7 +134,7 @@ config = LoraConfig(
138134
# apply LoRA to q and v
139135
target_modules=["q", "v"],
140136
lora_dropout=0.1,
141-
task_type="mask-generation"
137+
task_type="FEATURE_EXTRACTION"
142138
)
143139
```
144140

@@ -152,5 +148,5 @@ Call [print_trainable_parameters](https://huggingface.co/docs/peft/package_refer
152148

153149
```py
154150
model.print_trainable_parameters()
155-
"trainable params: 608,256 || all params: 94,343,728 || trainable%: 0.6447"
151+
"trainable params: 589,824 || all params: 94,274,096 || trainable%: 0.6256"
156152
```

0 commit comments

Comments
 (0)