Skip to content

Commit fb9b26b

Browse files
authored
Dependency Update (#40)
* Update dependencies
* Update a pytest setting
* Use sequence-label
1 parent d6e470f commit fb9b26b

File tree

7 files changed

+62
-60
lines changed

7 files changed

+62
-60
lines changed

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,4 @@ jobs:
3434
mypy tests
3535
- name: Run tests
3636
run: |
37-
pytest --cov=spacy_partial_tagger --cov-report=term-missing
37+
pytest

pyproject.toml

+4-6
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ build-backend = "hatchling.build"
55
[project]
66
name = "spacy-partial-tagger"
77
description = "Sequence Tagger for Partially Annotated Dataset in spaCy"
8-
requires-python = ">=3.8,<4.0"
8+
requires-python = ">=3.8,<3.12"
99
readme = {file = "README.md", content-type = "text/markdown"}
1010
license = {file = "LICENSE"}
1111
authors = [
@@ -22,7 +22,8 @@ dependencies = [
2222
"torch<3.0.0,>=2.0.1",
2323
"spacy[transformers]<4.0.0,>=3.3.1",
2424
"spacy-alignments<1.0.0,>=0.8.5",
25-
"pytorch-partial-tagger<1.0.0,>=0.1.14",
25+
"pytorch-partial-tagger<1.0.0,>=0.1.15",
26+
"sequence-label<1.0.0,>=0.1.4",
2627
]
2728
dynamic = ["version"]
2829

@@ -89,7 +90,4 @@ max-complexity = 18
8990
testpaths = [
9091
"tests",
9192
]
92-
addopts = "--strict-markers -m 'not local'"
93-
markers = [
94-
"local"
95-
]
93+
addopts = "--cov=spacy_partial_tagger --cov-report=term-missing -vv"

spacy_partial_tagger/collator.py

+12-14
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from typing import Optional, Tuple
1+
from typing import List, Optional, Tuple
22

3-
from partial_tagger.data import Alignment, Alignments, Span
43
from partial_tagger.data.collators import BaseCollator, Batch, TransformerCollator
4+
from sequence_label.core import LabelAlignment, Span
55
from transformers import AutoTokenizer
66
from transformers.models.bert_japanese import BertJapaneseTokenizer
77

@@ -23,34 +23,32 @@ def __init__(
2323
}
2424
self.__tokenizer_args["return_offsets_mapping"] = True
2525

26-
def __call__(self, texts: Tuple[str]) -> Tuple[Batch, Alignments]:
26+
def __call__(
27+
self, texts: Tuple[str, ...]
28+
) -> Tuple[Batch, Tuple[LabelAlignment, ...]]:
2729
batch_encoding = self.__tokenizer(texts, **self.__tokenizer_args)
2830

2931
pad_token_id = self.__tokenizer.pad_token_id
3032
mask = batch_encoding.input_ids != pad_token_id
31-
tokenized_text_lengths = mask.sum(dim=1)
3233

3334
alignments = []
34-
for _tokenized_text_length, input_ids, text in zip(
35-
tokenized_text_lengths, batch_encoding.input_ids, texts
36-
):
35+
for input_ids, text in zip(batch_encoding.input_ids, texts):
3736
char_spans = tuple(
38-
Span(span[0], len(span)) if span else None
37+
Span(start=span[0], length=len(span)) if span else None
3938
for span in get_alignments(self.__tokenizer, text, input_ids.tolist())
4039
)
41-
token_indices = [-1] * len(text)
40+
token_spans: List[Optional[Span]] = [None] * len(text)
4241
for token_index, char_span in enumerate(char_spans):
4342
if char_span is None:
4443
continue
4544
start = char_span.start
4645
end = char_span.start + char_span.length
47-
token_indices[start:end] = [token_index] * char_span.length
46+
for i in range(start, end):
47+
token_spans[i] = Span(start=token_index, length=1)
4848

49-
alignments.append(Alignment(text, char_spans, tuple(token_indices)))
49+
alignments.append(LabelAlignment(char_spans, tuple(token_spans)))
5050

51-
return Batch(tagger_inputs=batch_encoding, mask=mask), Alignments(
52-
tuple(alignments)
53-
)
51+
return Batch(tagger_inputs=batch_encoding, mask=mask), tuple(alignments)
5452

5553

5654
def get_collator(

spacy_partial_tagger/pipeline.py

+37-31
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, cast
1+
from typing import Callable, Dict, Iterable, List, Optional, Tuple, cast
22

33
import srsly
44
import torch
5-
from partial_tagger.data import Alignments, LabelSet
6-
from partial_tagger.training import compute_partially_supervised_loss
7-
from partial_tagger.utils import create_tag
5+
from partial_tagger.training import compute_partially_supervised_loss, create_tag_bitmap
6+
from sequence_label import LabelSet, SequenceLabel
87
from spacy import util
98
from spacy.errors import Errors
109
from spacy.language import Language
@@ -39,7 +38,9 @@ def __init__(
3938

4039
@property
4140
def label_set(self) -> LabelSet:
42-
return LabelSet(set(self.cfg["labels"]))
41+
return LabelSet(
42+
labels=set(self.cfg["labels"]), padding_index=self.padding_index
43+
)
4344

4445
def predict(self, docs: List[Doc]) -> Floats2d:
4546
(_, tag_indices) = self.model.predict(docs)
@@ -50,16 +51,14 @@ def set_annotations(
5051
docs: List[Doc],
5152
tag_indices: Floats2d,
5253
) -> None:
53-
alignments = Alignments(tuple(doc.user_data["alignment"] for doc in docs))
54-
tags_batch = alignments.create_char_based_tags(
55-
tag_indices.tolist(),
56-
label_set=self.label_set,
57-
padding_index=self.padding_index,
54+
labels = self.label_set.decode(
55+
tag_indices=tag_indices.tolist(),
56+
alignments=tuple(doc.user_data["alignment"] for doc in docs),
5857
)
5958

60-
for doc, tags in zip(docs, tags_batch):
59+
for doc, label in zip(docs, labels):
6160
ents = []
62-
for tag in tags:
61+
for tag in label.tags:
6362
span = doc.char_span(tag.start, tag.start + tag.length, tag.label)
6463
if span:
6564
ents.append(span)
@@ -89,19 +88,17 @@ def update(
8988
losses[self.name] += loss
9089
return losses
9190

92-
def initialize(
93-
self, get_examples: Callable, *, nlp: Language, labels: Optional[dict] = None
94-
) -> None:
91+
def initialize(self, get_examples: Callable, *, nlp: Language) -> None:
9592
X_small: List[Doc] = []
96-
label: Set[str] = set()
93+
labels: List[str] = []
9794
for example in get_examples():
9895
if len(X_small) < 10:
9996
X_small.append(example.x)
10097
for entity in example.y.ents:
101-
if entity.label_ not in label:
102-
label.add(entity.label_)
98+
if entity.label_ not in labels:
99+
labels.append(entity.label_)
103100

104-
self.cfg["labels"] = list(label)
101+
self.cfg["labels"] = labels
105102

106103
self.model.initialize(
107104
X=X_small,
@@ -113,23 +110,32 @@ def get_loss(
113110
) -> Tuple[float, Floats4d]:
114111
scores_pt = xp2torch(scores, requires_grad=True)
115112

116-
char_based_tags = []
117-
temp = []
113+
labels = []
114+
alignments = []
118115
lengths = []
119116
for example in examples:
120-
tags = tuple(
121-
create_tag(ent.start_char, len(ent.text), ent.label_)
122-
for ent in example.y.ents
117+
labels.append(
118+
SequenceLabel.from_dict(
119+
tags=[
120+
{
121+
"start": ent.start_char,
122+
"end": ent.end_char,
123+
"label": ent.label_,
124+
}
125+
for ent in example.y.ents
126+
],
127+
size=len(example.y.text),
128+
)
123129
)
124-
char_based_tags.append(tags)
125130

126131
alignment = example.x.user_data["alignment"]
127-
lengths.append(alignment.num_tokens)
128-
temp.append(alignment)
132+
alignments.append(alignment)
133+
lengths.append(alignment.target_size)
129134

130-
alignments = Alignments(tuple(temp))
131-
tag_bitmap = torch.tensor(
132-
alignments.get_tag_bitmap(char_based_tags, self.label_set),
135+
tag_bitmap = create_tag_bitmap(
136+
label_set=self.label_set,
137+
labels=tuple(labels),
138+
alignments=tuple(alignments),
133139
device=scores_pt.device,
134140
)
135141

@@ -140,7 +146,7 @@ def get_loss(
140146
)
141147

142148
loss = compute_partially_supervised_loss(
143-
scores_pt, tag_bitmap, mask, self.label_set.get_outside_index()
149+
scores_pt, tag_bitmap, mask, self.label_set.outside_index
144150
)
145151

146152
(grad,) = torch.autograd.grad(loss, scores_pt)

spacy_partial_tagger/tagger.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from functools import partial
22
from typing import Any, Callable, List, Optional, Tuple, cast
33

4-
from partial_tagger.data import LabelSet
4+
from sequence_label import LabelSet
55
from spacy.tokens import Doc
66
from spacy.util import registry
77
from thinc.api import Model, get_torch_default_device, torch2xp, xp2torch
@@ -45,7 +45,7 @@ def forward(
4545
collator = model.attrs["collator"]
4646
batch, alignments = collator(tuple(doc.text for doc in X))
4747

48-
for doc, alignment in zip(X, alignments.alignments):
48+
for doc, alignment in zip(X, alignments):
4949
doc.user_data["alignment"] = alignment
5050

5151
device = get_torch_default_device()

spacy_partial_tagger/util.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from typing import List, Tuple
22

33
import spacy_alignments as tokenizations
4-
from partial_tagger.data import LabelSet
54
from partial_tagger.decoders.viterbi import Constrainer, ViterbiDecoder
65
from partial_tagger.encoders.transformer import TransformerModelEncoderFactory
76
from partial_tagger.tagger import SequenceTagger
7+
from sequence_label import LabelSet
88
from transformers import PreTrainedTokenizer
99

1010

@@ -16,9 +16,9 @@ def create_tagger(
1616
ViterbiDecoder(
1717
padding_index,
1818
Constrainer(
19-
label_set.get_start_states(),
20-
label_set.get_end_states(),
21-
label_set.get_transitions(),
19+
label_set.start_states,
20+
label_set.end_states,
21+
label_set.transitions,
2222
),
2323
),
2424
)

tests/test_tagger.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from partial_tagger.data import LabelSet
1+
from sequence_label import LabelSet
22
from spacy.language import Language
33
from spacy.tokens import Doc
44

@@ -22,6 +22,6 @@ def test_partial_tagger(nlp: Language) -> None:
2222
(log_potentials, tag_indices), _ = tagger(docs, is_train=False)
2323

2424
# 10 is the length of sub-words of text.
25-
num_tags = label_set.get_tag_size()
25+
num_tags = label_set.state_size
2626
assert log_potentials.shape == (1, 10, num_tags, num_tags)
2727
assert tag_indices.shape == (1, 10)

0 commit comments

Comments (0)