
Commit 8080aa6

fix syntax errors on roberta and lora
1 parent c087adc commit 8080aa6

6 files changed: +104 -27 lines

exps/run_en_ro.sh (+14 -2)

@@ -38,6 +38,15 @@ ffn_adapter_init_option="lora"
 ffn_adapter_scalar="4"
 ffn_bn=512 # ffn bottleneck dim
 
+
+# lora params are not set
+if [ -z ${lora_alpha+x} ];
+then
+lora_alpha=0
+lora_init="lora"
+lora_dropout=0
+fi
+
 # set to 1 for debug mode which only
 # uses 1600 training examples
 debug=0

@@ -48,13 +57,13 @@ report_to="none"
 label_smoothing_factor=0.1
 weight_decay=0.01
 
-# the prefix tuning baseline prefers the
+# the prefix tuning baseline prefers the
 # commented hyperparam
 # label_smoothing_factor=0
 # weight_decay=0
 
 # note that the bsz argument is only effective at evaluation but
-# does not influence the training -- it is overridden by
+# does not influence the training -- it is overridden by
 # max_tokens_per_batch
 bsz=10
 max_steps=50000

@@ -128,6 +137,9 @@ python -u examples/pytorch/translation/run_translation.py \
 --adam_epsilon 1e-6 \
 --dropout 0.1 \
 --attention_dropout 0.0 \
+--lora_alpha ${lora_alpha} \
+--lora_dropout ${lora_dropout} \
+--lora_init ${lora_init} \
 --attn_mode ${attn_mode} \
 --attn_option ${attn_option} \
 --attn_composition ${attn_composition} \
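The guard added above relies on the ${lora_alpha+x} parameter expansion: it expands to x when lora_alpha is set (even to the empty string) and to nothing when it is unset, so the defaults are applied only when the caller has not already supplied LoRA hyperparameters. A minimal standalone sketch of the idiom (not taken from the repo):

    #!/usr/bin/env bash
    # If lora_alpha was never set, fall back to the placeholder defaults;
    # values exported by the caller are left untouched.
    if [ -z "${lora_alpha+x}" ]; then
        lora_alpha=0
        lora_init="lora"
        lora_dropout=0
    fi
    echo "lora_alpha=${lora_alpha} lora_init=${lora_init} lora_dropout=${lora_dropout}"

The script leaves the expansion unquoted, which still behaves correctly here: with the variable unset the test collapses to [ -z ], a one-argument test that is true, and with it set the test becomes [ -z x ], which is false.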

exps/run_glue.sh (+33)

@@ -50,6 +50,35 @@ ffn_adapter_init_option="lora"
 ffn_adapter_scalar="2"
 ffn_bn=16 # ffn bottleneck dim
 
+# ----- lora -----
+# attn_mode="lora"
+# attn_option="none"
+# attn_composition="add"
+# attn_bn=16
+
+# set ffn_mode to be 'lora' to use
+# lora at ffn as well
+
+# ffn_mode="lora"
+# ffn_option="none"
+# ffn_adapter_layernorm_option="none"
+# ffn_adapter_init_option="bert"
+# ffn_adapter_scalar="1"
+# ffn_bn=16
+
+# lora_alpha=32
+# lora_dropout=0.1
+# lora_init="lora"
+
+
+# lora params are not set
+if [ -z ${lora_alpha+x} ];
+then
+lora_alpha=0
+lora_init="lora"
+lora_dropout=0
+fi
+
 # set to 1 for debug mode which only
 # uses 1600 training examples
 debug=0

@@ -104,6 +133,7 @@ then
 fi
 
 
+
 # for seed in "${seed_list[@]}"; do
 
 exp_name=glue.${TASK_NAME}.am_${attn_mode}.ao_${attn_option}.fm_${ffn_mode}

@@ -133,6 +163,9 @@ python -u examples/pytorch/text-classification/run_glue.py \
 --adam_beta1 0.9 \
 --adam_beta2 0.98 \
 --adam_epsilon 1e-6 \
+--lora_alpha ${lora_alpha} \
+--lora_dropout ${lora_dropout} \
+--lora_init ${lora_init} \
 --attn_mode ${attn_mode} \
 --attn_option ${attn_option} \
 --attn_composition ${attn_composition} \
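To actually run GLUE with LoRA, the commented block is meant to be uncommented, in which case the guard sees lora_alpha as set and leaves the values alone. A sketch of that configuration with the values copied from the comments above; treat the numbers as illustrative rather than tuned hyperparameters:

    attn_mode="lora"            # apply LoRA to the attention query/value projections
    attn_option="none"
    attn_composition="add"
    attn_bn=16                  # passed as the LoRA rank r to the modified Linear layers

    ffn_mode="lora"             # optional: LoRA on the FFN projections as well
    ffn_option="none"
    ffn_adapter_layernorm_option="none"
    ffn_adapter_init_option="bert"
    ffn_adapter_scalar="1"
    ffn_bn=16

    lora_alpha=32               # scaling numerator (effective scale is typically lora_alpha / r)
    lora_dropout=0.1
    lora_init="lora"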

exps/run_xsum.sh (+38 -6)

@@ -36,7 +36,7 @@ ffn_adapter_init_option="lora"
 ffn_adapter_scalar="4"
 ffn_bn=512 # ffn bottleneck dim
 
-# ----- prefix tuning baseline -----
+# ----- prefix tuning baseline -----
 # attn_mode="prefix"
 # attn_option="concat"
 # attn_composition="add"

@@ -49,7 +49,7 @@ ffn_bn=512 # ffn bottleneck dim
 # ffn_adapter_scalar="4"
 # ffn_bn=512 # ffn bottleneck dim
 
-# ----- Houlsby Adapter -----
+# ----- Houlsby Adapter -----
 # attn_mode="adapter"
 # attn_option="sequential"
 # attn_composition="add"

@@ -63,7 +63,7 @@ ffn_bn=512 # ffn bottleneck dim
 # ffn_bn=200 # ffn bottleneck dim
 
 
-# ----- FFN Scaled Parallel Adapter -----
+# ----- FFN Scaled Parallel Adapter -----
 # attn_mode="none"
 # attn_option="parallel"
 # attn_composition="add"

@@ -76,7 +76,7 @@ ffn_bn=512 # ffn bottleneck dim
 # ffn_adapter_scalar="4"
 # ffn_bn=512 # ffn bottleneck dim
 
-# ----- Prompt Tuning -----
+# ----- Prompt Tuning -----
 # attn_mode="prompt_tuning"
 # attn_option="parallel"
 # attn_composition="add"

@@ -89,7 +89,7 @@ ffn_bn=512 # ffn bottleneck dim
 # ffn_adapter_scalar="4"
 # ffn_bn=512 # ffn bottleneck dim
 
-# ----- bitfit -----
+# ----- bitfit -----
 # attn_mode="bitfit"
 # attn_option="parallel"
 # attn_composition="add"

@@ -102,6 +102,35 @@ ffn_bn=512 # ffn bottleneck dim
 # ffn_adapter_scalar="4"
 # ffn_bn=512 # ffn bottleneck dim
 
+# ----- lora -----
+# attn_mode="lora"
+# attn_option="none"
+# attn_composition="add"
+# attn_bn=16
+
+# # set ffn_mode to be 'lora' to use
+# # lora at ffn as well
+
+# ffn_mode="none"
+# ffn_option="none"
+# ffn_adapter_layernorm_option="none"
+# ffn_adapter_init_option="bert"
+# ffn_adapter_scalar="1"
+# ffn_bn=16
+
+# lora_alpha=32
+# lora_dropout=0.1
+# lora_init="lora"
+
+
+# lora params are not set
+if [ -z ${lora_alpha+x} ];
+then
+lora_alpha=0
+lora_init="lora"
+lora_dropout=0
+fi
+
 
 # set to 1 for debug mode which only
 # uses 1600 training examples

@@ -161,13 +190,16 @@ SAVE=checkpoints/${dataset}/${DATE}/${exp_name}
 
 rm -rf ${SAVE}; mkdir -p ${SAVE}
 
-
 rm checkpoints/hf_model/downloads/*.lock
+rm checkpoints/hf_model/*.lock
 
 python -u examples/pytorch/summarization/run_summarization.py \
 --dataset_name 'xsum' \
 --model_name_or_path 'facebook/bart-large' \
 --cache_dir ${cache_dir} \
+--lora_alpha ${lora_alpha} \
+--lora_dropout ${lora_dropout} \
+--lora_init ${lora_init} \
 --attn_mode ${attn_mode} \
 --attn_option ${attn_option} \
 --attn_composition ${attn_composition} \
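Because the guard only fires when lora_alpha is unset, the LoRA hyperparameters can also be overridden from the environment without editing the script. A hypothetical invocation (assuming the script is launched directly with bash; any positional arguments it expects are omitted here):

    # Environment assignments take precedence over the in-script defaults;
    # the guard sees lora_alpha as set and skips the fallback values.
    lora_alpha=32 lora_dropout=0.1 lora_init="lora" bash exps/run_xsum.sh

Note that attn_mode and ffn_mode are still chosen inside the script, so an override like this only changes the LoRA hyperparameters forwarded to run_summarization.py.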

src/transformers/models/bart/modeling_bart.py (+2 -2)

@@ -153,9 +153,9 @@ def __init__(
 
         if config.attn_mode == "lora":
             self.q_proj = Linear(embed_dim, embed_dim, r=config.attn_bn, lora_alpha=config.lora_alpha,
-                                 lora_dropout=config.lora_dropout)
+                                 lora_dropout=config.lora_dropout, lora_init=config.lora_init)
             self.v_proj = Linear(embed_dim, embed_dim, r=config.attn_bn, lora_alpha=config.lora_alpha,
-                                 lora_dropout=config.lora_dropout)
+                                 lora_dropout=config.lora_dropout, lora_init=config.lora_init)
         else:
             self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
             self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

src/transformers/models/mbart/modeling_mbart.py (+6 -6)

@@ -160,9 +160,9 @@ def __init__(
 
         if config.attn_mode == "lora":
             self.q_proj = Linear(embed_dim, embed_dim, r=config.attn_bn, lora_alpha=config.lora_alpha,
-                                 lora_dropout=config.lora_dropout)
+                                 lora_dropout=config.lora_dropout, lora_init=config.lora_init)
             self.v_proj = Linear(embed_dim, embed_dim, r=config.attn_bn, lora_alpha=config.lora_alpha,
-                                 lora_dropout=config.lora_dropout)
+                                 lora_dropout=config.lora_dropout, lora_init=config.lora_init)
         else:
             self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
             self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

@@ -404,9 +404,9 @@ def __init__(self, config: MBartConfig):
 
         if config.ffn_mode == 'lora':
             self.fc1 = Linear(self.embed_dim, config.encoder_ffn_dim, r=config.ffn_bn, lora_alpha=config.lora_alpha,
-                              lora_dropout=config.lora_dropout)
+                              lora_dropout=config.lora_dropout, lora_init=config.lora_init)
             self.fc2 = Linear(config.encoder_ffn_dim, self.embed_dim, r=config.ffn_bn, lora_alpha=config.lora_alpha,
-                              lora_dropout=config.lora_dropout)
+                              lora_dropout=config.lora_dropout, lora_init=config.lora_init)
         else:
             self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
             self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)

@@ -528,9 +528,9 @@ def __init__(self, config: MBartConfig):
         self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
         if config.ffn_mode == 'lora':
             self.fc1 = Linear(self.embed_dim, config.decoder_ffn_dim, r=config.ffn_bn, lora_alpha=config.lora_alpha,
-                              lora_dropout=config.lora_dropout)
+                              lora_dropout=config.lora_dropout, lora_init=config.lora_init)
             self.fc2 = Linear(config.decoder_ffn_dim, self.embed_dim, r=config.ffn_bn, lora_alpha=config.lora_alpha,
-                              lora_dropout=config.lora_dropout)
+                              lora_dropout=config.lora_dropout, lora_init=config.lora_init)
         else:
             self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
             self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)

src/transformers/models/roberta/modeling_roberta.py (+11 -11)

@@ -176,9 +176,9 @@ def __init__(self, config, cache_key=None):
 
         if config.attn_mode == "lora":
             self.query = Linear(config.hidden_size, self.all_head_size, r=config.attn_bn, lora_alpha=config.lora_alpha,
-                                lora_dropout=config.lora_dropout)
+                                lora_dropout=config.lora_dropout, lora_init=config.lora_init)
             self.value = Linear(config.hidden_size, self.all_head_size, r=config.attn_bn, lora_alpha=config.lora_alpha,
-                                lora_dropout=config.lora_dropout)
+                                lora_dropout=config.lora_dropout, lora_init=config.lora_init)
         else:
             self.query = nn.Linear(config.hidden_size, self.all_head_size)
             self.value = nn.Linear(config.hidden_size, self.all_head_size)

@@ -202,14 +202,14 @@ def __init__(self, config, cache_key=None):
         if self.config.attn_option == 'cross_attn' or self.config.attn_option == 'cross_attn_relu':
             self.ef_transform_layer_norm = nn.LayerNorm(config.hidden_size)
 
-        elif self.attn_mode == 'adapter':
-            self.ef_attn_adapter = Adapter_Layer(self.config,
-                                                 dropout=self.dropout,
+        elif self.attn_mode == 'adapter' and self.config.attn_option == 'parallel':
+            self.ef_attn_adapter = Adapter_Layer(d_model=config.hidden_size,
+                                                 dropout=config.attention_probs_dropout_prob,
                                                  bottleneck=self.config.attn_bn,
                                                  adapter_layernorm_option="in",
                                                  )
-        elif self.attn_mode != 'none':
-            raise ValueError("att_mode not supported")
+        # elif self.attn_mode != 'none':
+        #     raise ValueError("att_mode not supported")
 
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)

@@ -417,8 +417,8 @@ def __init__(self, config):
         self.config = config
 
         if config.attn_mode == "adapter" and config.attn_option == "sequential":
-            self.ef_attn_adapter = Adapter_Layer(self.config,
-                                                 dropout=self.dropout,
+            self.ef_attn_adapter = Adapter_Layer(d_model=config.hidden_size,
+                                                 dropout=config.attention_probs_dropout_prob,
                                                  bottleneck=self.config.attn_bn,
                                                  adapter_layernorm_option="in",
                                                  )

@@ -491,7 +491,7 @@ def __init__(self, config):
         super().__init__()
         if config.ffn_mode == 'lora':
             self.dense = Linear(config.hidden_size, config.intermediate_size, r=config.ffn_bn, lora_alpha=config.lora_alpha,
-                                lora_dropout=config.lora_dropout)
+                                lora_dropout=config.lora_dropout, lora_init=config.lora_init)
         else:
             self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
         if isinstance(config.hidden_act, str):

@@ -511,7 +511,7 @@ def __init__(self, config):
         super().__init__()
         if config.ffn_mode == 'lora':
             self.dense = Linear(config.intermediate_size, config.hidden_size, r=config.ffn_bn, lora_alpha=config.lora_alpha,
-                                lora_dropout=config.lora_dropout)
+                                lora_dropout=config.lora_dropout, lora_init=config.lora_init)
         else:
             self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
