@@ -90,11 +90,6 @@ class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
 
         attn_weights = (query * self.scale) @ key.transpose(-2, -1)
 
-        if self.use_rel_pos:
-            attn_weights = self.add_decomposed_rel_pos(
-                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
-            )
-
         attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
         attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
         attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
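This hunk only touches the forward pass. For orientation, here is a minimal sketch of what the constructor of `SamVisionAttentionSplit` presumably looks like: it drops the fused `qkv` projection inherited from `SamVisionAttention` and registers separate `q`, `k`, and `v` linears so LoRA can later target `"q"` and `"v"` individually. The attribute names are assumptions based on the `target_modules` used below, not something this diff confirms.

```py
import torch.nn as nn
from transformers.models.sam.modeling_sam import SamVisionAttention

class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
    def __init__(self, config, window_size):
        super().__init__(config, window_size)
        # drop the fused qkv projection created by the parent class
        del self.qkv
        # separate projections (assumed names), each covering one slice
        # of the original fused weight
        self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.k = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.v = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
```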
@@ -114,13 +109,14 @@ Load the model with [`~PreTrainedModel.from_pretrained`].
 
 ```py
 from transformers import SamModel
-from transformers.models.sam import modeling_sam
-
-# replace the attention class in the modeling_sam module
-modeling_sam.SamVisionAttention = SamVisionAttentionSplit
 
 # load the pretrained SAM model
 model = SamModel.from_pretrained("facebook/sam-vit-base")
+
+# replace the attention module in each vision encoder layer
+for layer in model.vision_encoder.layers:
+    if hasattr(layer, "attn"):
+        layer.attn = SamVisionAttentionSplit(model.config.vision_config, model.config.vision_config.window_size)
 ```
 
 ## LoRA
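Before the LoRA hunk, a quick sanity check (hypothetical, not part of the diff) confirms the per-layer swap took effect:

```py
# every vision encoder layer that has an `attn` attribute should now
# hold the split attention class
assert all(
    isinstance(layer.attn, SamVisionAttentionSplit)
    for layer in model.vision_encoder.layers
    if hasattr(layer, "attn")
)
```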
@@ -138,7 +134,7 @@ config = LoraConfig(
     # apply LoRA to q and v
     target_modules=["q", "v"],
     lora_dropout=0.1,
-    task_type="mask-generation"
+    task_type="FEATURE_EXTRACTION"
 )
 ```
 
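The config on its own changes nothing; assuming the standard PEFT workflow, it still has to be wired into the model (a sketch using `get_peft_model` from the `peft` library):

```py
from peft import get_peft_model

# wrap the SAM model; only the LoRA adapters on q and v remain trainable
model = get_peft_model(model, config)
```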
@@ -152,5 +148,5 @@ Call [print_trainable_parameters](https://huggingface.co/docs/peft/package_refer
 
 ```py
 model.print_trainable_parameters()
-"trainable params: 608,256 || all params: 94,343,728 || trainable%: 0.6447"
+"trainable params: 589,824 || all params: 94,274,096 || trainable%: 0.6256"
 ```
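The new count checks out arithmetically, assuming a LoRA rank of 16 (the rank is set outside this hunk, so treat it as an assumption): SAM ViT-base has 12 encoder layers of hidden size 768, and each targeted projection gains an A (r × d) and a B (d × r) matrix.

```py
# assumed values: r=16, hidden size 768, 12 layers, 2 targets (q and v)
r, d, layers, targets = 16, 768, 12, 2
print(targets * layers * 2 * r * d)  # 589824, matching the new output above
```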