-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval.py
71 lines (56 loc) · 1.84 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
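"""Visualize the bad predictions of a NER model using Streamlit.

Shows every dev-set example whose predicted entities differ from the gold
annotations, with the reference and the prediction rendered side by side.

Usage:
    streamlit run eval.py [-- --nlp MODEL_PATH --docbin DOCBIN_PATH]

Without options, the model is loaded from ./models/sm/training/model-last
and the reference docs from ./data/dev.spacy.
"""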
from pathlib import Path
from typing import Iterable
import spacy
import streamlit as st
import typer
from spacy.tokens import DocBin
from spacy.training import Example
from spacy_streamlit import visualize_ner
def load_examples(
    nlp: spacy.Language | str | Path | None = None,
    doc_bin: DocBin | str | Path | None = None,
) -> list[Example]:
    """Pair each reference doc with the pipeline's prediction for its text.

    Args:
        nlp: A loaded spaCy pipeline, or a path (``str`` or ``Path``) passed
            to ``spacy.load``. Defaults to ``./models/sm/training/model-last``.
        doc_bin: A ``DocBin`` of reference (gold) docs, or a path (``str`` or
            ``Path``) to a serialized one. Defaults to ``./data/dev.spacy``.

    Returns:
        One ``Example(predicted, reference)`` per reference doc, in the order
        the docs come out of the ``DocBin``.
    """
    if nlp is None:
        nlp = Path("./models/sm/training/model-last")
    # Accept plain strings too: a str would otherwise skip loading and fail
    # later when the code reaches for ``nlp.vocab``.
    if isinstance(nlp, (str, Path)):
        nlp = spacy.load(nlp)

    if doc_bin is None:
        doc_bin = Path("./data/dev.spacy")
    if isinstance(doc_bin, (str, Path)):
        doc_bin = DocBin().from_disk(doc_bin)

    # Materialize the reference docs (built on the pipeline's vocab): they are
    # read twice below, once for their text and once to pair with predictions.
    references = list(doc_bin.get_docs(nlp.vocab))

    # nlp.pipe processes the texts as a batched stream, which is usually
    # faster than calling nlp() once per text; it yields one doc per text, in order.
    predictions = nlp.pipe(reference.text for reference in references)
    return [
        Example(predicted, reference)
        for predicted, reference in zip(predictions, references, strict=True)
    ]
def is_misprediction(example: Example) -> bool:
    """Return True if the predicted entities differ from the reference ones.

    Entities are compared by character span *and* label, so a prediction with
    the right label but wrong boundaries (e.g. "New York" instead of
    "New York City") counts as a misprediction, as do missing or extra
    entities. Character offsets are used rather than token indices because
    the predicted doc may be tokenized differently from the reference doc.
    """
    spans_ref = [
        (ent.start_char, ent.end_char, ent.label_) for ent in example.reference.ents
    ]
    spans_pred = [
        (ent.start_char, ent.end_char, ent.label_) for ent in example.predicted.ents
    ]
    return spans_ref != spans_pred
def viz_examples(examples: Iterable[Example]):
    """Show each mispredicted example as reference vs. prediction, side by side.

    Switches the Streamlit page to the wide layout, then, for every example
    flagged by ``is_misprediction``, renders its text as a title, two NER
    panels ("Reference" on the left, "Predicted" on the right) and a divider.
    Examples that are not flagged are skipped.

    Widget keys are built from the example's position in *examples*
    (skipped examples still advance the counter), so every panel key is unique.
    """
    st.set_page_config(layout="wide")

    flagged = (
        (position, example)
        for position, example in enumerate(examples)
        if is_misprediction(example)
    )
    for position, example in flagged:
        st.title(example.text)
        left, right = st.columns(2)
        panels = (
            (left, example.reference, f"{position}_ref", "Reference"),
            (right, example.predicted, f"{position}_pred", "Predicted"),
        )
        for column, doc, widget_key, heading in panels:
            with column:
                visualize_ner(doc, show_table=False, key=widget_key, title=heading)
        st.divider()
def main(nlp: None | Path = None, docbin: None | Path = None):
    """Load a spaCy NER pipeline and reference docs, then show its mistakes.

    Both paths are optional; when omitted, load_examples falls back to its
    built-in default locations.
    """
    viz_examples(load_examples(nlp, docbin))
# Script entry point: typer builds a command-line interface from main()'s
# signature and invokes it with the parsed arguments.
# NOTE(review): the module docstring says to launch this via `streamlit run`;
# this guard assumes Streamlit executes the file with __name__ == "__main__"
# and forwards arguments given after `--` in sys.argv — confirm for the
# Streamlit version in use.
if __name__ == "__main__":
    typer.run(main)