Commit 975a033

Update pytorch-partial-tagger (#41)
1 parent: fb9b26b · commit: 975a033

File tree

2 files changed: +101 −12 lines

  spacy_partial_tagger/pipeline.py  (+36 −1)
  spacy_partial_tagger/util.py  (+65 −11)

spacy_partial_tagger/pipeline.py

+36 −1
@@ -2,7 +2,8 @@
 
 import srsly
 import torch
-from partial_tagger.training import compute_partially_supervised_loss, create_tag_bitmap
+from partial_tagger.crf import functional as F
+from partial_tagger.training import create_tag_bitmap
 from sequence_label import LabelSet, SequenceLabel
 from spacy import util
 from spacy.errors import Errors
@@ -18,6 +19,40 @@
 from thinc.types import Floats2d, Floats4d
 
 
+def compute_partially_supervised_loss(
+    log_potentials: torch.Tensor,
+    tag_bitmap: torch.Tensor,
+    mask: torch.Tensor,
+    outside_index: int,
+    target_entity_ratio: float = 0.15,
+    entity_ratio_margin: float = 0.05,
+    balancing_coefficient: int = 10,
+) -> torch.Tensor:
+    with torch.enable_grad():
+        # log partition
+        log_Z = F.forward_algorithm(log_potentials)
+
+        # marginal probabilities
+        p = torch.autograd.grad(log_Z.sum(), log_potentials, create_graph=True)[0].sum(
+            dim=-1
+        )
+    p *= mask[..., None]
+
+    expected_entity_count = (
+        p[:, :, :outside_index].sum() + p[:, :, outside_index + 1 :].sum()
+    )
+    expected_entity_ratio = expected_entity_count / p.sum()
+    expected_entity_ratio_loss = torch.clamp(
+        (expected_entity_ratio - target_entity_ratio).abs() - entity_ratio_margin,
+        min=0,
+    )
+
+    score = F.multitag_sequence_score(log_potentials, tag_bitmap, mask)
+    supervised_loss = (log_Z - score).mean()
+
+    return supervised_loss + balancing_coefficient * expected_entity_ratio_loss
+
+
 class PartialEntityRecognizer(TrainablePipe):
     def __init__(
         self,
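The inlined compute_partially_supervised_loss combines two terms: a marginal CRF negative log-likelihood over the allowed-tag bitmap (log_Z - score), and an expected entity ratio penalty that differentiates the log partition function to obtain per-token tag marginals and pushes the expected fraction of non-outside tags toward target_entity_ratio within entity_ratio_margin. Below is a minimal sketch of calling it directly; the tensor shapes and the choice of outside_index=0 are assumptions for illustration, not taken from this commit.

    import torch

    from spacy_partial_tagger.pipeline import compute_partially_supervised_loss

    # Assumed shapes: the CRF emits log potentials of shape
    # (batch, sequence_length, num_tags, num_tags).
    batch, length, num_tags = 2, 6, 5

    log_potentials = torch.randn(batch, length, num_tags, num_tags, requires_grad=True)

    # True marks a tag as allowed at a position; leaving every tag allowed
    # mimics a completely unannotated span.
    tag_bitmap = torch.ones(batch, length, num_tags, dtype=torch.bool)
    mask = torch.ones(batch, length, dtype=torch.bool)

    # Tag 0 plays the role of the outside ("O") tag in this sketch.
    loss = compute_partially_supervised_loss(
        log_potentials, tag_bitmap, mask, outside_index=0
    )
    loss.backward()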

spacy_partial_tagger/util.py

+65 −11
@@ -1,26 +1,80 @@
-from typing import List, Tuple
+from typing import Dict, List, Tuple, cast
 
 import spacy_alignments as tokenizations
-from partial_tagger.decoders.viterbi import Constrainer, ViterbiDecoder
+import torch
+from partial_tagger.crf import functional as F
+from partial_tagger.crf.nn import CRF
+from partial_tagger.encoders.base import BaseEncoder
 from partial_tagger.encoders.transformer import TransformerModelEncoderFactory
-from partial_tagger.tagger import SequenceTagger
 from sequence_label import LabelSet
+from torch import nn
 from transformers import PreTrainedTokenizer
 
 
+class SequenceTagger(nn.Module):
+    def __init__(
+        self,
+        encoder: BaseEncoder,
+        padding_index: int,
+        start_states: Tuple[bool, ...],
+        end_states: Tuple[bool, ...],
+        transitions: Tuple[Tuple[bool, ...], ...],
+    ):
+        super().__init__()
+
+        self.encoder = encoder
+        self.crf = CRF(encoder.get_hidden_size())
+        self.start_constraints = nn.Parameter(
+            torch.tensor(start_states), requires_grad=False
+        )
+        self.end_constraints = nn.Parameter(
+            torch.tensor(end_states), requires_grad=False
+        )
+        self.transition_constraints = nn.Parameter(
+            torch.tensor(transitions), requires_grad=False
+        )
+        self.padding_index = padding_index
+
+    def __constrain(
+        self, log_potentials: torch.Tensor, mask: torch.Tensor
+    ) -> torch.Tensor:
+        return F.constrain_log_potentials(
+            log_potentials,
+            mask,
+            self.start_constraints,
+            self.end_constraints,
+            self.transition_constraints,
+        )
+
+    def forward(
+        self, inputs: Dict[str, torch.Tensor], mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        log_potentials = self.crf(self.encoder(inputs), mask)
+
+        constrained = self.__constrain(log_potentials, mask)
+
+        constrained.requires_grad_()
+
+        with torch.enable_grad():
+            _, tag_indices = F.decode(constrained)
+
+        return log_potentials, tag_indices * mask + self.padding_index * (~mask)
+
+    def predict(
+        self, inputs: Dict[str, torch.Tensor], mask: torch.Tensor
+    ) -> torch.Tensor:
+        return cast(torch.Tensor, self(inputs, mask)[1])
+
+
 def create_tagger(
     model_name: str, label_set: LabelSet, padding_index: int
 ) -> SequenceTagger:
     return SequenceTagger(
         TransformerModelEncoderFactory(model_name).create(label_set),
-        ViterbiDecoder(
-            padding_index,
-            Constrainer(
-                label_set.start_states,
-                label_set.end_states,
-                label_set.transitions,
-            ),
-        ),
+        padding_index,
+        label_set.start_states,
+        label_set.end_states,
+        label_set.transitions,
     )
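SequenceTagger is now defined in this module rather than imported from partial_tagger.tagger: it stacks a CRF head on the transformer encoder, clamps the log potentials with the label set's start, end, and transition constraints, and runs decoding inside torch.enable_grad() after requires_grad_(), presumably because F.decode recovers the best tag indices through autograd and must keep working when callers run under torch.no_grad(). A rough usage sketch follows; the model name, label names, LabelSet construction, and the encoder's input dictionary are assumptions for illustration, not part of this commit.

    import torch
    from sequence_label import LabelSet
    from transformers import AutoTokenizer

    from spacy_partial_tagger.util import create_tagger

    # Assumed LabelSet constructor; the label names are placeholders.
    label_set = LabelSet({"LOC", "ORG", "PER"})
    tagger = create_tagger("roberta-base", label_set, padding_index=-1)

    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    batch = tokenizer(["John lives in Tokyo."], return_tensors="pt")

    # Assumes the transformer encoder consumes the tokenizer's tensors as-is.
    inputs = {
        "input_ids": batch["input_ids"],
        "attention_mask": batch["attention_mask"],
    }
    mask = batch["attention_mask"].bool()

    # Shape (1, sequence_length); padding_index fills masked-out positions.
    tag_indices = tagger.predict(inputs, mask)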
