1
- from typing import Callable , Dict , Iterable , List , Optional , Set , Tuple , cast
1
+ from typing import Callable , Dict , Iterable , List , Optional , Tuple , cast
2
2
3
3
import srsly
4
4
import torch
5
- from partial_tagger .data import Alignments , LabelSet
6
- from partial_tagger .training import compute_partially_supervised_loss
7
- from partial_tagger .utils import create_tag
5
+ from partial_tagger .training import compute_partially_supervised_loss , create_tag_bitmap
6
+ from sequence_label import LabelSet , SequenceLabel
8
7
from spacy import util
9
8
from spacy .errors import Errors
10
9
from spacy .language import Language
@@ -39,7 +38,9 @@ def __init__(
39
38
40
39
@property
41
40
def label_set (self ) -> LabelSet :
42
- return LabelSet (set (self .cfg ["labels" ]))
41
+ return LabelSet (
42
+ labels = set (self .cfg ["labels" ]), padding_index = self .padding_index
43
+ )
43
44
44
45
def predict (self , docs : List [Doc ]) -> Floats2d :
45
46
(_ , tag_indices ) = self .model .predict (docs )
@@ -50,16 +51,14 @@ def set_annotations(
50
51
docs : List [Doc ],
51
52
tag_indices : Floats2d ,
52
53
) -> None :
53
- alignments = Alignments (tuple (doc .user_data ["alignment" ] for doc in docs ))
54
- tags_batch = alignments .create_char_based_tags (
55
- tag_indices .tolist (),
56
- label_set = self .label_set ,
57
- padding_index = self .padding_index ,
54
+ labels = self .label_set .decode (
55
+ tag_indices = tag_indices .tolist (),
56
+ alignments = tuple (doc .user_data ["alignment" ] for doc in docs ),
58
57
)
59
58
60
- for doc , tags in zip (docs , tags_batch ):
59
+ for doc , label in zip (docs , labels ):
61
60
ents = []
62
- for tag in tags :
61
+ for tag in label . tags :
63
62
span = doc .char_span (tag .start , tag .start + tag .length , tag .label )
64
63
if span :
65
64
ents .append (span )
@@ -89,19 +88,17 @@ def update(
89
88
losses [self .name ] += loss
90
89
return losses
91
90
92
- def initialize (
93
- self , get_examples : Callable , * , nlp : Language , labels : Optional [dict ] = None
94
- ) -> None :
91
+ def initialize (self , get_examples : Callable , * , nlp : Language ) -> None :
95
92
X_small : List [Doc ] = []
96
- label : Set [str ] = set ()
93
+ labels : List [str ] = []
97
94
for example in get_examples ():
98
95
if len (X_small ) < 10 :
99
96
X_small .append (example .x )
100
97
for entity in example .y .ents :
101
- if entity .label_ not in label :
102
- label . add (entity .label_ )
98
+ if entity .label_ not in labels :
99
+ labels . append (entity .label_ )
103
100
104
- self .cfg ["labels" ] = list ( label )
101
+ self .cfg ["labels" ] = labels
105
102
106
103
self .model .initialize (
107
104
X = X_small ,
@@ -113,23 +110,32 @@ def get_loss(
113
110
) -> Tuple [float , Floats4d ]:
114
111
scores_pt = xp2torch (scores , requires_grad = True )
115
112
116
- char_based_tags = []
117
- temp = []
113
+ labels = []
114
+ alignments = []
118
115
lengths = []
119
116
for example in examples :
120
- tags = tuple (
121
- create_tag (ent .start_char , len (ent .text ), ent .label_ )
122
- for ent in example .y .ents
117
+ labels .append (
118
+ SequenceLabel .from_dict (
119
+ tags = [
120
+ {
121
+ "start" : ent .start_char ,
122
+ "end" : ent .end_char ,
123
+ "label" : ent .label_ ,
124
+ }
125
+ for ent in example .y .ents
126
+ ],
127
+ size = len (example .y .text ),
128
+ )
123
129
)
124
- char_based_tags .append (tags )
125
130
126
131
alignment = example .x .user_data ["alignment" ]
127
- lengths .append (alignment . num_tokens )
128
- temp .append (alignment )
132
+ alignments .append (alignment )
133
+ lengths .append (alignment . target_size )
129
134
130
- alignments = Alignments (tuple (temp ))
131
- tag_bitmap = torch .tensor (
132
- alignments .get_tag_bitmap (char_based_tags , self .label_set ),
135
+ tag_bitmap = create_tag_bitmap (
136
+ label_set = self .label_set ,
137
+ labels = tuple (labels ),
138
+ alignments = tuple (alignments ),
133
139
device = scores_pt .device ,
134
140
)
135
141
@@ -140,7 +146,7 @@ def get_loss(
140
146
)
141
147
142
148
loss = compute_partially_supervised_loss (
143
- scores_pt , tag_bitmap , mask , self .label_set .get_outside_index ()
149
+ scores_pt , tag_bitmap , mask , self .label_set .outside_index
144
150
)
145
151
146
152
(grad ,) = torch .autograd .grad (loss , scores_pt )
0 commit comments