Skip to content

Commit 3f61b59

Browse files
make lemma_graph undirected
1 parent 73ee61b commit 3f61b59

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

Diff for: pytextrank/base.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def __init__ (
309309
# effectively, performs the same work as the `reset()` method;
310310
# called explicitly here for the sake of type annotations
311311
self.elapsed_time: float = 0.0
312-
self.lemma_graph: nx.DiGraph = nx.DiGraph()
312+
self.lemma_graph: nx.Graph = nx.Graph()
313313
self.phrases: typing.List[Phrase] = []
314314
self.ranks: typing.Dict[Lemma, float] = {}
315315
self.seen_lemma: typing.Dict[Lemma, typing.Set[int]] = OrderedDict()
@@ -323,7 +323,7 @@ def reset (
323323
removing any pre-existing state.
324324
"""
325325
self.elapsed_time = 0.0
326-
self.lemma_graph = nx.DiGraph()
326+
self.lemma_graph = nx.Graph()
327327
self.phrases = []
328328
self.ranks = {}
329329
self.seen_lemma = OrderedDict()
@@ -400,15 +400,15 @@ def get_personalization ( # pylint: disable=R0201
400400

401401
def _construct_graph (
402402
self
403-
) -> nx.DiGraph:
403+
) -> nx.Graph:
404404
"""
405405
Construct the
406406
[*lemma graph*](https://derwen.ai/docs/ptr/glossary/#lemma-graph).
407407
408408
returns:
409409
an undirected graph representing the lemma graph
410410
"""
411-
g = nx.DiGraph()
411+
g = nx.Graph()
412412

413413
# add nodes made of Lemma(lemma, pos)
414414
g.add_nodes_from(self.node_list)
@@ -571,6 +571,8 @@ def _calc_discounted_normalised_rank (
571571
returns:
572572
normalized rank metric
573573
"""
574+
if len(span) < 1 :
575+
return 0.0
574576
non_lemma = len([tok for tok in span if tok.pos_ not in self.pos_kept])
575577
non_lemma_discount = len(span) / (len(span) + (2.0 * non_lemma) + 1.0)
576578

@@ -877,7 +879,7 @@ def write_dot (
877879
path:
878880
path for the output file; defaults to `"graph.dot"`
879881
"""
880-
dot = graphviz.Digraph()
882+
dot = graphviz.Graph()
881883

882884
for lemma in self.lemma_graph.nodes():
883885
rank = self.ranks[lemma]

Diff for: tests/test_base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -154,13 +154,13 @@ def test_stop_words ():
154154
for phrase in doc._.phrases[:5]
155155
]
156156

157-
assert "words" in phrases
157+
assert "sentences" in phrases
158158

159159
# add `"sentence": ["NOUN"]` to the *stop words*, to remove instances
160160
# of `"sentence"` or `"sentences"` then see how the ranked phrases differ?
161161

162162
nlp2 = spacy.load("en_core_web_sm")
163-
nlp2.add_pipe("textrank", config={ "stopwords": { "word": ["NOUN"] } })
163+
nlp2.add_pipe("textrank", config={ "stopwords": { "sentence": ["NOUN"] } })
164164

165165
with open("dat/gen.txt", "r") as f:
166166
doc = nlp2(f.read())

0 commit comments

Comments
 (0)