diff --git a/tapas/experiments/tapas_pretraining_experiment.py b/tapas/experiments/tapas_pretraining_experiment.py
index a7b4307..dba11c2 100644
--- a/tapas/experiments/tapas_pretraining_experiment.py
+++ b/tapas/experiments/tapas_pretraining_experiment.py
@@ -27,6 +27,18 @@
 
 FLAGS = flags.FLAGS
 
+flags.DEFINE_integer(
+    "restrict_attention_bucket_size", 0, "For sparse attention modes, further "
+    "restrict attention to consecutive buckets of uniform size.")
+
+flags.DEFINE_integer(
+    "restrict_attention_header_size", None, "For sparse attention modes, size "
+    "of the first section that will attend to/from everything else.")
+
+flags.DEFINE_float(
+    "restrict_attention_row_heads_ratio", 0.5, "For sparse attention modes, "
+    "proportion of heads that should focus on rows vs columns.")
+
 flags.DEFINE_string("data_format", "tfrecord", "The input data format.")
 
 flags.DEFINE_string(