
Commit 01f0fd0

Fixes for LayoutLM (#7318)
1 parent 702a76f commit 01f0fd0

3 files changed: +31 -41 lines changed

src/transformers/configuration_layoutlm.py (+17 -17)
@@ -40,40 +40,40 @@ class LayoutLMConfig(BertConfig):

    Args:
-        vocab_size (:obj:`int`, optional, defaults to 30522):
+        vocab_size (:obj:`int`, `optional`, defaults to 30522):
            Vocabulary size of the LayoutLM model. Defines the different tokens that
            can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.LayoutLMModel`.
-        hidden_size (:obj:`int`, optional, defaults to 768):
+        hidden_size (:obj:`int`, `optional`, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
-        num_hidden_layers (:obj:`int`, optional, defaults to 12):
+        num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
            Number of hidden layers in the Transformer encoder.
-        num_attention_heads (:obj:`int`, optional, defaults to 12):
+        num_attention_heads (:obj:`int`, `optional`, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
-        intermediate_size (:obj:`int`, optional, defaults to 3072):
+        intermediate_size (:obj:`int`, `optional`, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
-        hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
+        hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler.
-            If string, "gelu", "relu", "swish" and "gelu_new" are supported.
-        hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1):
+            If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
+        hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
            The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
-        attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
+        attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
            The dropout ratio for the attention probabilities.
-        max_position_embeddings (:obj:`int`, optional, defaults to 512):
+        max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
            The maximum sequence length that this model might ever be used with.
            Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
-        type_vocab_size (:obj:`int`, optional, defaults to 2):
-            The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
-        initializer_range (:obj:`float`, optional, defaults to 0.02):
+        type_vocab_size (:obj:`int`, `optional`, defaults to 2):
+            The vocabulary size of the :obj:`token_type_ids` passed into :class:`~transformers.LayoutLMModel`.
+        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
-        layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
+        layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
-        gradient_checkpointing (:obj:`bool`, optional, defaults to :obj:`False`):
+        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
-        max_2d_position_embeddings (:obj:`int`, optional, defaults to 1024):
+        max_2d_position_embeddings (:obj:`int`, `optional`, defaults to 1024):
            The maximum value that the 2D position embedding might ever used.
            Typically set this to something large just in case (e.g., 1024).

-    Example::
+    Examples::

        >>> from transformers import LayoutLMModel, LayoutLMConfig
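For readers who want to see the documented parameters in use, here is a minimal sketch in the spirit of the `Examples::` block being renamed above. The example in the file is truncated in this view, so the snippet below is an illustration rather than the committed text; the values simply mirror the documented defaults.

# Minimal sketch (not part of this commit) exercising the parameters documented above.
from transformers import LayoutLMConfig, LayoutLMModel

# A configuration with the documented defaults, plus the LayoutLM-specific
# 2D position embedding size.
configuration = LayoutLMConfig(
    vocab_size=30522,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    max_position_embeddings=512,
    max_2d_position_embeddings=1024,
)

# A randomly initialized model built from that configuration.
model = LayoutLMModel(configuration)

# The configuration can be read back from the model.
configuration = model.config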

src/transformers/modeling_layoutlm.py (+12 -22)
@@ -118,6 +118,7 @@ def forward(
         return embeddings


+# Copied from transformers.modeling_bert.BertSelfAttention with Bert->LayoutLM
 class LayoutLMSelfAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -172,6 +173,7 @@ def forward(
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
         if attention_mask is not None:
+            # Apply the attention mask is (precomputed for all layers in LayoutLMModel forward() function)
             attention_scores = attention_scores + attention_mask

         # Normalize the attention scores to probabilities.
@@ -195,6 +197,7 @@ def forward(
         return outputs


+# Copied from transformers.modeling_bert.BertSelfOutput with Bert->LayoutLM
 class LayoutLMSelfOutput(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -209,6 +212,7 @@ def forward(self, hidden_states, input_tensor):
         return hidden_states


+# Copied from transformers.modeling_bert.BertAttention with Bert->LayoutLM
 class LayoutLMAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -256,6 +260,7 @@ def forward(
         return outputs


+# Copied from transformers.modeling_bert.BertIntermediate
 class LayoutLMIntermediate(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -271,6 +276,7 @@ def forward(self, hidden_states):
         return hidden_states


+# Copied from transformers.modeling_bert.BertOutput with Bert->LayoutLM
 class LayoutLMOutput(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -285,6 +291,7 @@ def forward(self, hidden_states, input_tensor):
         return hidden_states


+# Copied from transformers.modeling_bert.BertLayer with Bert->LayoutLM
 class LayoutLMLayer(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -344,6 +351,7 @@ def feed_forward_chunk(self, attention_output):
         return layer_output


+# Copied from transformers.modeling_bert.BertEncoder with Bert->LayoutLM
 class LayoutLMEncoder(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -408,6 +416,7 @@ def custom_forward(*inputs):
         )


+# Copied from transformers.modeling_bert.BertPooler
 class LayoutLMPooler(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -423,6 +432,7 @@ def forward(self, hidden_states):
         return pooled_output


+# Copied from transformers.modeling_bert.BertPredictionHeadTransform with Bert->LayoutLM
 class LayoutLMPredictionHeadTransform(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -440,6 +450,7 @@ def forward(self, hidden_states):
         return hidden_states


+# Copied from transformers.modeling_bert.BertLMPredictionHead with Bert->LayoutLM
 class LayoutLMLMPredictionHead(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -460,6 +471,7 @@ def forward(self, hidden_states):
         return hidden_states


+# Copied from transformers.modeling_bert.BertOnlyMLMHead with Bert->LayoutLM
 class LayoutLMOnlyMLMHead(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -470,28 +482,6 @@ def forward(self, sequence_output):
         return prediction_scores


-class LayoutLMOnlyNSPHead(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.seq_relationship = nn.Linear(config.hidden_size, 2)
-
-    def forward(self, pooled_output):
-        seq_relationship_score = self.seq_relationship(pooled_output)
-        return seq_relationship_score
-
-
-class LayoutLMPreTrainingHeads(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.predictions = LayoutLMLMPredictionHead(config)
-        self.seq_relationship = nn.Linear(config.hidden_size, 2)
-
-    def forward(self, sequence_output, pooled_output):
-        prediction_scores = self.predictions(sequence_output)
-        seq_relationship_score = self.seq_relationship(pooled_output)
-        return prediction_scores, seq_relationship_score
-
-
 class LayoutLMPreTrainedModel(PreTrainedModel):
     """An abstract class to handle weights initialization and
     a simple interface for downloading and loading pretrained models.
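The comment added in the second hunk notes that the attention mask is "precomputed for all layers in LayoutLMModel forward() function": the per-layer self-attention modules only ever add it to the raw attention scores before the softmax. The standalone sketch below illustrates that additive-mask convention under the usual assumption of a large negative fill value for padded positions; it is an illustration of the idea, not the helper the library itself uses.

# Standalone sketch of the additive attention-mask convention referenced above
# (illustrative only; not the library's internal utility).
import torch

def extend_attention_mask(attention_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # attention_mask: (batch, seq_len) with 1 for real tokens and 0 for padding.
    # Broadcast to (batch, 1, 1, seq_len) so it applies to every head and query position.
    extended = attention_mask[:, None, None, :].to(dtype)
    # Real tokens add 0; padded positions add a large negative number, so the
    # softmax assigns them (near) zero attention probability.
    return (1.0 - extended) * -10000.0

mask = torch.tensor([[1, 1, 1, 0, 0]])            # one sequence, two padding tokens
additive = extend_attention_mask(mask)            # shape (1, 1, 1, 5)
scores = torch.randn(1, 12, 5, 5)                 # (batch, heads, query, key) raw scores
probs = torch.softmax(scores + additive, dim=-1)  # padded keys receive ~0 weight

Because the mask is already in this additive form when it reaches each layer, every copied self-attention module can simply do `attention_scores = attention_scores + attention_mask`, exactly as in the context lines above.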

src/transformers/modeling_roberta.py (+2 -2)
@@ -142,7 +142,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
         return position_ids.unsqueeze(0).expand(input_shape)


-# Copied from transformers.modeling_bert.BertSelfAttention
+# Copied from transformers.modeling_bert.BertSelfAttention with Bert->Roberta
 class RobertaSelfAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -197,7 +197,7 @@ def forward(
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
         if attention_mask is not None:
-            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+            # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
             attention_scores = attention_scores + attention_mask

         # Normalize the attention scores to probabilities.
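Both files lean on the repository's "# Copied from X with A->B" markers: the idea is that a copied class can be regenerated from its source by applying the textual rename and compared against what is checked in. The toy sketch below illustrates that idea under stated assumptions; the function names are hypothetical and the repository's real consistency check is a separate, more involved script.

# Toy sketch of checking a "Copied from X with Bert->Roberta" marker
# (hypothetical helpers; not the repository's actual check script).
import re

def apply_replacement(source_code: str, pattern: str) -> str:
    # pattern looks like "Bert->Roberta": apply the rename throughout the source.
    old, new = pattern.split("->")
    return re.sub(old, new, source_code)

def is_consistent(copied_code: str, original_code: str, pattern: str) -> bool:
    # The copy is consistent if it equals the original after the rename.
    return copied_code.strip() == apply_replacement(original_code, pattern).strip()

original = "class BertSelfOutput:\n    name = 'Bert'\n"
copied = "class RobertaSelfOutput:\n    name = 'Roberta'\n"
print(is_consistent(copied, original, "Bert->Roberta"))  # True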
