Skip to content

Commit

Permalink
Merge pull request #1 from nodestream-proj/improvements
Browse files Browse the repository at this point in the history
Improving code quality and adding tests
  • Loading branch information
zprobst authored Nov 13, 2024
2 parents 38f9147 + 6657c1d commit 45706e2
Show file tree
Hide file tree
Showing 11 changed files with 473 additions and 267 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.11"]
python-version: ["3.10", "3.11", "3.12", "3.13"]

steps:
- uses: actions/checkout@v3
Expand Down
7 changes: 7 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"python.testing.pytestArgs": [
"."
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
5 changes: 3 additions & 2 deletions nodestream_plugin_semantic/chunk.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from abc import ABC, abstractmethod
from typing import Iterable

from nodestream.subclass_registry import SubclassRegistry
from nodestream.pluggable import Pluggable
from nodestream.subclass_registry import SubclassRegistry

from .model import Content

Expand All @@ -25,7 +25,8 @@ def from_file_data(type, **chunker_kwargs) -> "Chunker":
return CHUNKER_SUBCLASS_REGISTRY.get(type)(**chunker_kwargs)

@abstractmethod
def chunk(self, content: Content) -> Iterable[Content]: ...
def chunk(self, content: Content) -> Iterable[Content]:
...


class SplitOnDelimiterChunker(Chunker):
Expand Down
3 changes: 1 addition & 2 deletions nodestream_plugin_semantic/content_types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from abc import ABC, abstractmethod
from typing import Iterable
from pathlib import Path
from typing import Iterable

from nodestream.subclass_registry import SubclassRegistry


CONTENT_TYPE_SUBCLASS_REGISTRY = SubclassRegistry()
PLAIN_TEXT_ALIAS = "plain_text"
PLAIN_TEXT_EXTENSIONS = {".txt", ".md"}
Expand Down
3 changes: 1 addition & 2 deletions nodestream_plugin_semantic/embed.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from abc import ABC, abstractmethod

from nodestream.subclass_registry import SubclassRegistry
from nodestream.pluggable import Pluggable
from nodestream.subclass_registry import SubclassRegistry

from .model import Content, Embedding


EMBEDDER_SUBCLASS_REGISTRY = SubclassRegistry()


Expand Down
10 changes: 5 additions & 5 deletions nodestream_plugin_semantic/model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from dataclasses import dataclass
import hashlib
from typing import List, Optional, Iterable

from nodestream.model import DesiredIngestion, Node
from dataclasses import dataclass
from typing import Iterable, List, Optional

from nodestream.model import DesiredIngestion, Node, Relationship

Embedding = List[float | int]
CONTENT_NODE_TYPE_ID_PROPERTY = "id"
Expand Down Expand Up @@ -66,8 +65,9 @@ def make_ingestible(

if self.parent:
self.parent.apply_to_node(node_type, related := Node())
relationship = Relationship(type=relationship_type)
ingest.add_relationship(
related_node=related, type=relationship_type, outbound=False
related_node=related, relationship=relationship, outbound=False
)

return ingest
10 changes: 5 additions & 5 deletions nodestream_plugin_semantic/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
from glob import glob
from pathlib import Path
from typing import Optional, List, Dict
from typing import Dict, List, Optional

from nodestream.model import DesiredIngestion
from nodestream.pipeline import Extractor, Transformer
from nodestream.pipeline.value_providers import (
ValueProvider,
JmespathValueProvider,
ProviderContext,
ValueProvider,
)
from nodestream.schema import (
Cardinality,
ExpandsSchema,
SchemaExpansionCoordinator,
GraphObjectSchema,
Cardinality,
SchemaExpansionCoordinator,
)

from .chunk import Chunker
from .content_types import PLAIN_TEXT_ALIAS, ContentType
from .embed import Embedder
from .content_types import ContentType, PLAIN_TEXT_ALIAS
from .model import Content

DEFAULT_ID = JmespathValueProvider.from_string_expression("id")
Expand Down
529 changes: 280 additions & 249 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,16 @@ nodestream = "^0.13.0"
semchunk = "^2.2.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
pytest = "^8.2.0"
pytest-mock = "^3.11.1"
ruff = "^0.0.284"
isort = "^5.12.0"
black = "^23.7.0"
pytest-cov = "^4.1.0"
pytest-asyncio = "^0.24.0"

[tool.isort]
profile = "black"

[build-system]
requires = ["poetry-core"]
Expand Down
65 changes: 65 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from unittest.mock import Mock

from nodestream.model import DesiredIngestion, Node

from nodestream_plugin_semantic.model import Content, hash


def test_content_from_text():
    """Content built from raw text hashes the text for its id and has no parent."""
    text = "test content"
    result = Content.from_text(text)
    assert result.content == text
    assert result.id == hash(text)
    assert result.parent is None


def test_content_add_metadata():
    """add_metadata records the key/value pair in the content's metadata dict."""
    subject = Content.from_text("test content")
    subject.add_metadata("key", "value")
    assert subject.metadata == {"key": "value"}


def test_content_split_on_delimiter():
    """Splitting yields one child Content per segment, each parented to the source."""
    source = Content.from_text("line1\nline2\nline3")
    chunks = list(source.split_on_delimiter("\n"))
    # One chunk per line, in order — this covers both the count and the contents.
    assert [chunk.content for chunk in chunks] == ["line1", "line2", "line3"]
    assert all(chunk.parent == source for chunk in chunks)


def test_content_assign_embedding():
    """assign_embedding attaches the vector to the content in place."""
    subject = Content.from_text("test content")
    vector = [0.1, 0.2, 0.3]
    subject.assign_embedding(vector)
    assert subject.embedding == vector


def test_content_apply_to_node():
    """apply_to_node sets the node type, the id key value, and the content property."""
    content = Content.from_text("test content")
    node = Mock(spec=Node)
    content.apply_to_node("test_type", node)
    # Bug fix: the original wrote `node.type = "test_type"` AFTER the call,
    # which assigned the attribute instead of asserting it — the check was a no-op.
    assert node.type == "test_type"
    node.key_values.set_property.assert_called_with("id", content.id)
    node.properties.set_property.assert_any_call("content", content.content)


def test_content_make_ingestible():
    """make_ingestible yields a DesiredIngestion with the child as source node and
    a single inbound relationship pointing at the parent's node."""
    parent_content = Content.from_text("parent content")
    child_content = Content.from_text("child content", parent=parent_content)
    ingest = child_content.make_ingestible("test_type", "test_relationship")

    assert isinstance(ingest, DesiredIngestion)
    assert ingest.source.type == "test_type"
    # Bug fix: these comparisons were missing `assert`, so they evaluated to a
    # discarded boolean and verified nothing. Check the specific entries rather
    # than whole-dict equality so extra properties set elsewhere don't break this.
    assert ingest.source.key_values["id"] == child_content.id
    assert ingest.source.properties["content"] == child_content.content

    assert len(ingest.relationships) == 1
    relationship = ingest.relationships[0]
    assert relationship.relationship.type == "test_relationship"
    assert relationship.outbound is False
    assert relationship.to_node.type == "test_type"
    # Same missing-assert bug on the related (parent) node's checks.
    assert relationship.to_node.key_values["id"] == parent_content.id
    assert relationship.to_node.properties["content"] == parent_content.content
100 changes: 100 additions & 0 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from nodestream.model import DesiredIngestion

from nodestream_plugin_semantic.model import Content
from nodestream_plugin_semantic.pipeline import (
ChunkContent,
ContentInterpreter,
ConvertToContent,
DocumentExtractor,
EmbedContent,
)


@pytest.mark.asyncio
async def test_chunk_content():
    """ChunkContent fans a single record out into one record per chunk."""
    fake_chunker = MagicMock()
    fake_chunker.chunk.return_value = [
        Content(id="1", content="chunk1"),
        Content(id="2", content="chunk2"),
    ]
    subject = ChunkContent(fake_chunker)
    source = Content(id="0", content="original content")
    produced = [chunk async for chunk in subject.transform_record(source)]
    # Exactly the chunker's output, in order.
    assert [chunk.content for chunk in produced] == ["chunk1", "chunk2"]


@pytest.mark.asyncio
async def test_embed_content():
    """EmbedContent attaches the embedder's result without touching the text."""
    fake_embedder = AsyncMock()
    fake_embedder.embed.return_value = "embedded content"
    subject = EmbedContent(fake_embedder)
    record = Content(id="0", content="original content")
    outcome = await subject.transform_record(record)
    assert outcome.content == "original content"
    assert outcome.embedding == "embedded content"


def test_document_extractor():
    """from_file_data expands globs into paths and delegates reads to the content type."""
    expected_paths = [Path("file1.txt"), Path("file2.txt")]
    fake_content_type = MagicMock()
    fake_content_type.is_supported.return_value = True
    fake_content_type.read.return_value = "file content"
    glob_patch = patch(
        "nodestream_plugin_semantic.pipeline.glob",
        return_value=["file1.txt", "file2.txt"],
    )
    by_name_patch = patch(
        "nodestream_plugin_semantic.pipeline.ContentType.by_name",
        return_value=fake_content_type,
    )
    with glob_patch, by_name_patch:
        extractor = DocumentExtractor.from_file_data(globs=["*.txt"])
        assert len(extractor.paths) == 2
        assert extractor.read(expected_paths[0]) == "file content"


@pytest.mark.asyncio
async def test_document_extractor_extract_records():
    """extract_records yields one Content record per discovered file."""
    fake_content_type = MagicMock()
    fake_content_type.is_supported.return_value = True
    fake_content_type.read.return_value = "file content"
    glob_patch = patch(
        "nodestream_plugin_semantic.pipeline.glob", return_value=["file1.txt"]
    )
    by_name_patch = patch(
        "nodestream_plugin_semantic.pipeline.ContentType.by_name",
        return_value=fake_content_type,
    )
    with glob_patch, by_name_patch:
        extractor = DocumentExtractor.from_file_data(globs=["*.txt"])
        produced = [record async for record in extractor.extract_records()]
        assert len(produced) == 1
        assert produced[0].content == "file content"


@pytest.mark.asyncio
async def test_convert_to_content():
    """ConvertToContent lifts a plain dict record into a Content object."""
    raw = {"id": "1", "content": "some content"}
    subject = ConvertToContent()
    produced = await subject.transform_record(raw)
    assert produced.id == "1"
    assert produced.content == "some content"


@pytest.mark.asyncio
async def test_content_interpreter():
    """ContentInterpreter turns a Content record into a DesiredIngestion."""
    record = Content(id="1", content="some content")
    subject = ContentInterpreter()
    produced = await subject.transform_record(record)
    assert isinstance(produced, DesiredIngestion)


def test_content_interpreter_expand_schema():
    """expand_schema registers node and relationship schemas and connects them."""
    subject = ContentInterpreter()
    fake_coordinator = MagicMock()
    subject.expand_schema(fake_coordinator)
    fake_coordinator.on_node_schema.assert_called()
    fake_coordinator.on_relationship_schema.assert_called()
    fake_coordinator.connect.assert_called()

0 comments on commit 45706e2

Please sign in to comment.