
Commit 147fc41

standalone viterbi to avoid dependency conflict
1 parent 9d31557 commit 147fc41

3 files changed: +185 -7 lines

loss/crf.py

Lines changed: 7 additions & 6 deletions
@@ -8,9 +8,10 @@
 
 import torch
 
-from allennlp.common.checks import ConfigurationError
-import allennlp.nn.util as util
+#from allennlp.common.checks import ConfigurationError
+#import allennlp.nn.util as util
 from util.util import *
+from util.viterbi import *
 
 
 def allowed_transitions(constraint_type: str, labels: Dict[int, str]) -> List[Tuple[int, int]]:
@@ -155,7 +156,7 @@ def is_transition_allowed(
             ]
         )
     else:
-        raise ConfigurationError(f"Unknown constraint type: {constraint_type}")
+        raise Exception(f"Unknown constraint type: {constraint_type}")
 
 
 class ConditionalRandomField(torch.nn.Module):
class ConditionalRandomField(torch.nn.Module):
@@ -259,7 +260,7 @@ def _input_likelihood(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
 
             # In valid positions (mask == 1) we want to take the logsumexp over the current_tag dimension
             # of ``inner``. Otherwise (mask == 0) we want to retain the previous alpha.
-            alpha = util.logsumexp(inner, 1) * mask[i].view(batch_size, 1) + alpha * (
+            alpha = torch.logsumexp(inner, 1) * mask[i].view(batch_size, 1) + alpha * (
                 1 - mask[i]
             ).view(batch_size, 1)
 
@@ -270,7 +271,7 @@ def _input_likelihood(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
             stops = alpha
 
         # Finally we log_sum_exp along the num_tags dim, result is (batch_size,)
-        return util.logsumexp(stops)
+        return torch.logsumexp(stops, dim=-1)
 
     def _joint_likelihood(
         self, logits: torch.Tensor, tags: torch.Tensor, mask: torch.LongTensor
@@ -408,7 +409,7 @@ def viterbi_tags(
             tag_sequence[sequence_length + 1, end_tag] = 0.0
 
             # We pass the tags and the transitions to ``viterbi_decode``.
-            viterbi_path, viterbi_score = util.viterbi_decode(
+            viterbi_path, viterbi_score = viterbi_decode(
                 tag_sequence[: (sequence_length + 2)], _transitions
             )
             # Get rid of START and END sentinels and append.
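
Note on the logsumexp replacement in this file: allennlp.nn.util.logsumexp reduces over dim=-1 by default, while torch.logsumexp requires the dim argument explicitly, so the bare call must pass it to keep the same behavior. A minimal sketch of the equivalence (shapes here are illustrative, not from the repository):

import torch

stops = torch.randn(4, 7)  # stand-in for the (batch_size, num_tags) scores

# util.logsumexp(stops) reduced over the last dim by default; with torch
# the dim must be given explicitly to get the same (batch_size,) result.
result = torch.logsumexp(stops, dim=-1)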

requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
 spacy==2.3.2
 h5py==2.10.0
-allennlp==1.1.0
 transformers==3.4.0
 torch==1.6.0
 numpy==1.18.5

util/viterbi.py

Lines changed: 178 additions & 0 deletions
# Standalone version from allennlp
import logging
import math
from typing import List, Optional

import torch

# allennlp's module defined a logger; define one here so the evidence
# warning below still works without the dependency.
logger = logging.getLogger(__name__)


def viterbi_decode(
    tag_sequence: torch.Tensor,
    transition_matrix: torch.Tensor,
    tag_observations: Optional[List[int]] = None,
    allowed_start_transitions: torch.Tensor = None,
    allowed_end_transitions: torch.Tensor = None,
    top_k: int = None,
):
    """
    Perform Viterbi decoding in log space over a sequence given a transition matrix
    specifying pairwise (transition) potentials between tags and a matrix of shape
    (sequence_length, num_tags) specifying unary potentials for possible tags per
    timestep.

    # Parameters

    tag_sequence : `torch.Tensor`, required.
        A tensor of shape (sequence_length, num_tags) representing scores for
        a set of tags over a given sequence.
    transition_matrix : `torch.Tensor`, required.
        A tensor of shape (num_tags, num_tags) representing the binary potentials
        for transitioning between a given pair of tags.
    tag_observations : `Optional[List[int]]`, optional, (default = `None`)
        A list of length `sequence_length` containing the class ids of observed
        elements in the sequence, with unobserved elements being set to -1. Note that
        it is possible to provide evidence which results in degenerate labelings if
        the sequences of tags you provide as evidence cannot transition between each
        other, or those transitions are extremely unlikely. In this situation we log a
        warning, but the responsibility for providing self-consistent evidence ultimately
        lies with the user.
    allowed_start_transitions : `torch.Tensor`, optional, (default = `None`)
        An optional tensor of shape (num_tags,) describing which tags the START token
        may transition *to*. If provided, additional transition constraints will be used for
        determining the start element of the sequence.
    allowed_end_transitions : `torch.Tensor`, optional, (default = `None`)
        An optional tensor of shape (num_tags,) describing which tags may transition *to* the
        end tag. If provided, additional transition constraints will be used for determining
        the end element of the sequence.
    top_k : `int`, optional, (default = `None`)
        Optional integer specifying how many of the top paths to return. For top_k >= 1,
        returns a tuple of two lists: top_k_paths, top_k_scores. For top_k == None, returns
        a flattened tuple with just the top path and its score (not in lists, for backwards
        compatibility).

    # Returns

    viterbi_path : `List[int]`
        The tag indices of the maximum likelihood tag sequence.
    viterbi_score : `torch.Tensor`
        The score of the viterbi path.
    """
    if top_k is None:
        top_k = 1
        flatten_output = True
    elif top_k >= 1:
        flatten_output = False
    else:
        raise ValueError(f"top_k must be either None or an integer >=1. Instead received {top_k}")

    sequence_length, num_tags = list(tag_sequence.size())

    has_start_end_restrictions = (
        allowed_end_transitions is not None or allowed_start_transitions is not None
    )

    if has_start_end_restrictions:
        if allowed_end_transitions is None:
            allowed_end_transitions = torch.zeros(num_tags)
        if allowed_start_transitions is None:
            allowed_start_transitions = torch.zeros(num_tags)

        num_tags = num_tags + 2
        new_transition_matrix = torch.zeros(num_tags, num_tags)
        new_transition_matrix[:-2, :-2] = transition_matrix

        # Start and end transitions are fully defined, but cannot transition between each other.
        allowed_start_transitions = torch.cat(
            [allowed_start_transitions, torch.tensor([-math.inf, -math.inf])]
        )
        allowed_end_transitions = torch.cat(
            [allowed_end_transitions, torch.tensor([-math.inf, -math.inf])]
        )

        # First define how we may transition FROM the start and end tags.
        new_transition_matrix[-2, :] = allowed_start_transitions
        # We cannot transition from the end tag to any tag.
        new_transition_matrix[-1, :] = -math.inf

        new_transition_matrix[:, -1] = allowed_end_transitions
        # We cannot transition to the start tag from any tag.
        new_transition_matrix[:, -2] = -math.inf

        transition_matrix = new_transition_matrix

    if tag_observations:
        if len(tag_observations) != sequence_length:
            # ConfigurationError came from allennlp; raise a ValueError instead
            # now that the dependency is gone.
            raise ValueError(
                "Observations were provided, but they were not the same length "
                "as the sequence. Found sequence of length: {} and evidence: {}".format(
                    sequence_length, tag_observations
                )
            )
    else:
        tag_observations = [-1 for _ in range(sequence_length)]

    if has_start_end_restrictions:
        tag_observations = [num_tags - 2] + tag_observations + [num_tags - 1]
        zero_sentinel = torch.zeros(1, num_tags)
        extra_tags_sentinel = torch.ones(sequence_length, 2) * -math.inf
        tag_sequence = torch.cat([tag_sequence, extra_tags_sentinel], -1)
        tag_sequence = torch.cat([zero_sentinel, tag_sequence, zero_sentinel], 0)
        sequence_length = tag_sequence.size(0)

    path_scores = []
    path_indices = []

    if tag_observations[0] != -1:
        one_hot = torch.zeros(num_tags)
        one_hot[tag_observations[0]] = 100000.0
        path_scores.append(one_hot.unsqueeze(0))
    else:
        path_scores.append(tag_sequence[0, :].unsqueeze(0))

    # Evaluate the scores for all possible paths.
    for timestep in range(1, sequence_length):
        # Add pairwise potentials to current scores.
        summed_potentials = path_scores[timestep - 1].unsqueeze(2) + transition_matrix
        summed_potentials = summed_potentials.view(-1, num_tags)

        # Best pairwise potential path score from the previous timestep.
        max_k = min(summed_potentials.size()[0], top_k)
        scores, paths = torch.topk(summed_potentials, k=max_k, dim=0)

        # If we have an observation for this timestep, use it
        # instead of the distribution over tags.
        observation = tag_observations[timestep]
        # Warn the user if they have passed
        # invalid/extremely unlikely evidence.
        if tag_observations[timestep - 1] != -1 and observation != -1:
            if transition_matrix[tag_observations[timestep - 1], observation] < -10000:
                logger.warning(
                    "The pairwise potential between tags you have passed as "
                    "observations is extremely unlikely. Double check your evidence "
                    "or transition potentials!"
                )
        if observation != -1:
            one_hot = torch.zeros(num_tags)
            one_hot[observation] = 100000.0
            path_scores.append(one_hot.unsqueeze(0))
        else:
            path_scores.append(tag_sequence[timestep, :] + scores)
        path_indices.append(paths.squeeze())

    # Construct the most likely sequence backwards.
    path_scores_v = path_scores[-1].view(-1)
    max_k = min(path_scores_v.size()[0], top_k)
    viterbi_scores, best_paths = torch.topk(path_scores_v, k=max_k, dim=0)
    viterbi_paths = []
    for i in range(max_k):
        viterbi_path = [best_paths[i]]
        for backward_timestep in reversed(path_indices):
            viterbi_path.append(int(backward_timestep.view(-1)[viterbi_path[-1]]))
        # Reverse the backward path.
        viterbi_path.reverse()

        if has_start_end_restrictions:
            viterbi_path = viterbi_path[1:-1]

        # Viterbi paths uses (num_tags * n_permutations) nodes; therefore, we need to modulo.
        viterbi_path = [j % num_tags for j in viterbi_path]
        viterbi_paths.append(viterbi_path)

    if flatten_output:
        return viterbi_paths[0], viterbi_scores[0]

    return viterbi_paths, viterbi_scores
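
For reference, a quick usage sketch of the standalone decoder; the shapes and values are illustrative, not taken from this repository:

import torch
from util.viterbi import viterbi_decode

tag_scores = torch.randn(5, 3)   # (sequence_length, num_tags) unary potentials
transitions = torch.randn(3, 3)  # (num_tags, num_tags) pairwise potentials

# Default (top_k=None): a single best path and its score.
path, score = viterbi_decode(tag_scores, transitions)

# top_k >= 1: lists of the k best paths and their scores.
paths, scores = viterbi_decode(tag_scores, transitions, top_k=2)

# Optional start/end constraints, shape (num_tags,): 0.0 allows, -inf forbids.
start_ok = torch.tensor([0.0, float("-inf"), 0.0])
end_ok = torch.zeros(3)
path2, score2 = viterbi_decode(
    tag_scores, transitions,
    allowed_start_transitions=start_ok,
    allowed_end_transitions=end_ok,
)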
