# Standalone version from allennlp
import logging
import math
from typing import List, Optional

import torch

logger = logging.getLogger(__name__)


def viterbi_decode(
    tag_sequence: torch.Tensor,
    transition_matrix: torch.Tensor,
    tag_observations: Optional[List[int]] = None,
    allowed_start_transitions: torch.Tensor = None,
    allowed_end_transitions: torch.Tensor = None,
    top_k: int = None,
):
    """
    Perform Viterbi decoding in log space over a sequence given a transition matrix
    specifying pairwise (transition) potentials between tags and a matrix of shape
    (sequence_length, num_tags) specifying unary potentials for possible tags per
    timestep.

    # Parameters

    tag_sequence : `torch.Tensor`, required.
        A tensor of shape (sequence_length, num_tags) representing scores for
        a set of tags over a given sequence.
    transition_matrix : `torch.Tensor`, required.
        A tensor of shape (num_tags, num_tags) representing the binary potentials
        for transitioning between a given pair of tags.
    tag_observations : `Optional[List[int]]`, optional, (default = `None`)
        A list of length `sequence_length` containing the class ids of observed
        elements in the sequence, with unobserved elements being set to -1. Note that
        it is possible to provide evidence which results in degenerate labelings if
        the sequences of tags you provide as evidence cannot transition between each
        other, or those transitions are extremely unlikely. In this situation we log a
        warning, but the responsibility for providing self-consistent evidence ultimately
        lies with the user.
    allowed_start_transitions : `torch.Tensor`, optional, (default = `None`)
        An optional tensor of shape (num_tags,) describing which tags the START token
        may transition *to*. If provided, additional transition constraints will be used for
        determining the start element of the sequence.
    allowed_end_transitions : `torch.Tensor`, optional, (default = `None`)
        An optional tensor of shape (num_tags,) describing which tags may transition *to* the
        end tag. If provided, additional transition constraints will be used for determining
        the end element of the sequence.
    top_k : `int`, optional, (default = `None`)
        Optional integer specifying how many of the top paths to return. For top_k >= 1, returns
        a tuple of two lists: top_k_paths, top_k_scores. For top_k == None, returns a flattened
        tuple with just the top path and its score (not in lists, for backwards compatibility).

    # Returns

    viterbi_path : `List[int]`
        The tag indices of the maximum likelihood tag sequence.
    viterbi_score : `torch.Tensor`
        The score of the viterbi path.
    """
    if top_k is None:
        top_k = 1
        flatten_output = True
    elif top_k >= 1:
        flatten_output = False
    else:
        raise ValueError(f"top_k must be either None or an integer >=1. Instead received {top_k}")

    sequence_length, num_tags = list(tag_sequence.size())

    has_start_end_restrictions = (
        allowed_end_transitions is not None or allowed_start_transitions is not None
    )

    if has_start_end_restrictions:

        if allowed_end_transitions is None:
            allowed_end_transitions = torch.zeros(num_tags)
        if allowed_start_transitions is None:
            allowed_start_transitions = torch.zeros(num_tags)

        num_tags = num_tags + 2
        new_transition_matrix = torch.zeros(num_tags, num_tags)
        new_transition_matrix[:-2, :-2] = transition_matrix

        # Start and end transitions are fully defined, but cannot transition between each other.

        allowed_start_transitions = torch.cat(
            [allowed_start_transitions, torch.tensor([-math.inf, -math.inf])]
        )
        allowed_end_transitions = torch.cat(
            [allowed_end_transitions, torch.tensor([-math.inf, -math.inf])]
        )

        # First, define how we may transition FROM the start and end tags.
        new_transition_matrix[-2, :] = allowed_start_transitions
        # We cannot transition from the end tag to any tag.
        new_transition_matrix[-1, :] = -math.inf

        new_transition_matrix[:, -1] = allowed_end_transitions
        # We cannot transition to the start tag from any tag.
        new_transition_matrix[:, -2] = -math.inf

        transition_matrix = new_transition_matrix

    if tag_observations:
        if len(tag_observations) != sequence_length:
            # allennlp raises its own ConfigurationError here; the standalone version uses ValueError.
            raise ValueError(
                "Observations were provided, but they were not the same length "
                "as the sequence. Found sequence of length: {} and evidence: {}".format(
                    sequence_length, tag_observations
                )
            )
    else:
        tag_observations = [-1 for _ in range(sequence_length)]

    if has_start_end_restrictions:
        tag_observations = [num_tags - 2] + tag_observations + [num_tags - 1]
        zero_sentinel = torch.zeros(1, num_tags)
        extra_tags_sentinel = torch.ones(sequence_length, 2) * -math.inf
        tag_sequence = torch.cat([tag_sequence, extra_tags_sentinel], -1)
        tag_sequence = torch.cat([zero_sentinel, tag_sequence, zero_sentinel], 0)
        sequence_length = tag_sequence.size(0)

    path_scores = []
    path_indices = []

    if tag_observations[0] != -1:
        one_hot = torch.zeros(num_tags)
        one_hot[tag_observations[0]] = 100000.0
        path_scores.append(one_hot.unsqueeze(0))
    else:
        path_scores.append(tag_sequence[0, :].unsqueeze(0))

    # Evaluate the scores for all possible paths.
    for timestep in range(1, sequence_length):
        # Add pairwise potentials to current scores.
        summed_potentials = path_scores[timestep - 1].unsqueeze(2) + transition_matrix
        summed_potentials = summed_potentials.view(-1, num_tags)

        # Best pairwise potential path score from the previous timestep.
        max_k = min(summed_potentials.size()[0], top_k)
        scores, paths = torch.topk(summed_potentials, k=max_k, dim=0)

        # If we have an observation for this timestep, use it
        # instead of the distribution over tags.
        observation = tag_observations[timestep]
        # Warn the user if they have passed
        # invalid/extremely unlikely evidence.
        if tag_observations[timestep - 1] != -1 and observation != -1:
            if transition_matrix[tag_observations[timestep - 1], observation] < -10000:
                logger.warning(
                    "The pairwise potential between tags you have passed as "
                    "observations is extremely unlikely. Double check your evidence "
                    "or transition potentials!"
                )
        if observation != -1:
            one_hot = torch.zeros(num_tags)
            one_hot[observation] = 100000.0
            path_scores.append(one_hot.unsqueeze(0))
        else:
            path_scores.append(tag_sequence[timestep, :] + scores)
        path_indices.append(paths.squeeze())

    # Construct the most likely sequence backwards.
    path_scores_v = path_scores[-1].view(-1)
    max_k = min(path_scores_v.size()[0], top_k)
    viterbi_scores, best_paths = torch.topk(path_scores_v, k=max_k, dim=0)
    viterbi_paths = []
    for i in range(max_k):
        viterbi_path = [best_paths[i]]
        for backward_timestep in reversed(path_indices):
            viterbi_path.append(int(backward_timestep.view(-1)[viterbi_path[-1]]))
        # Reverse the backward path.
        viterbi_path.reverse()

        if has_start_end_restrictions:
            viterbi_path = viterbi_path[1:-1]

        # Viterbi paths use (num_tags * n_permutations) nodes; therefore, we need to modulo.
        viterbi_path = [j % num_tags for j in viterbi_path]
        viterbi_paths.append(viterbi_path)

    if flatten_output:
        return viterbi_paths[0], viterbi_scores[0]

    return viterbi_paths, viterbi_scores
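

# The demo below is a minimal usage sketch added for illustration; the toy emission and
# transition scores (and the constrained-decoding call) are invented, not part of the
# original allennlp code.
if __name__ == "__main__":
    # Three timesteps over two tags: unary scores favour tag 0 at steps 0 and 2,
    # and tag 1 at step 1.
    emissions = torch.tensor(
        [[5.0, 0.0],
         [0.0, 4.0],
         [3.0, 1.0]]
    )
    # Pairwise potentials: a small bonus for staying on the same tag.
    transitions = torch.tensor(
        [[1.0, 0.0],
         [0.0, 1.0]]
    )

    # With top_k left as None, a single path and its score are returned.
    best_path, best_score = viterbi_decode(emissions, transitions)
    print(best_path, best_score)  # expected to favour the path [0, 1, 0] under these toy scores

    # With top_k=2, lists of the two best paths and their scores are returned instead.
    paths, scores = viterbi_decode(emissions, transitions, top_k=2)
    print(paths, scores)

    # Constrained decoding sketch: forbid starting or ending the sequence on tag 1
    # by assigning -inf to the corresponding start/end transitions.
    constrained_path, constrained_score = viterbi_decode(
        emissions,
        transitions,
        allowed_start_transitions=torch.tensor([0.0, -math.inf]),
        allowed_end_transitions=torch.tensor([0.0, -math.inf]),
    )
    print(constrained_path, constrained_score)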