Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions awswrangler/neptune/_neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def to_property_graph(
... )
"""
# check if ~id and ~label column exist and if not throw error
g = gremlin.traversal().withGraph(gremlin.Graph())
g = gremlin.Graph().traversal()
is_edge_df = False
is_update_df = True
if "~id" in df.columns:
Expand Down Expand Up @@ -203,6 +203,24 @@ def to_property_graph(
return _run_gremlin_insert(client, g)


# SPARQL 1.1 IRIREF grammar: '<' ([^<>"{}|^`\]-[#x00-#x20])* '>'
# A cell value spliced between '<' and '>' must contain only the characters allowed
# inside the IRIREF token. Anything else can close the token and inject arbitrary
# SPARQL UPDATE syntax (DELETE / DROP / LOAD / ...).
_IRIREF_INNER_RE = re.compile(r"^[^\x00-\x20<>\"{}|^`\\]*$")


def _validate_iriref_cell(value: Any, column: str, row_index: int) -> str:
text = str(value)
if not _IRIREF_INNER_RE.match(text):
raise exceptions.InvalidArgumentValue(
f"Value in column {column!r} at row index {row_index} is not a valid IRI: "
f"{text!r}. Cells written by `to_rdf_graph` must conform to the SPARQL "
'IRIREF grammar (no whitespace, control characters, or any of <>"{}|^`\\).'
)
return text


@_utils.check_optional_dependency(sparql, "SPARQLWrapper")
def to_rdf_graph(
client: NeptuneClient,
Expand Down Expand Up @@ -267,14 +285,18 @@ def to_rdf_graph(
query = ""
# Loop through items in the DF
for i, (_, row) in enumerate(df.iterrows()):
subject = _validate_iriref_cell(row[subject_column], subject_column, i)
predicate = _validate_iriref_cell(row[predicate_column], predicate_column, i)
obj = _validate_iriref_cell(row[object_column], object_column, i)
# build up a query
if is_quads:
insert = f"""INSERT DATA {{ GRAPH <{row[graph_column]}> {{<{row[subject_column]}>
<{str(row[predicate_column])}> <{row[object_column]}> . }} }}; """
graph = _validate_iriref_cell(row[graph_column], graph_column, i)
insert = f"""INSERT DATA {{ GRAPH <{graph}> {{<{subject}>
<{predicate}> <{obj}> . }} }}; """
query = query + insert
else:
insert = f"""INSERT DATA {{ <{row[subject_column]}> <{str(row[predicate_column])}>
<{row[object_column]}> . }}; """
insert = f"""INSERT DATA {{ <{subject}> <{predicate}>
<{obj}> . }}; """
query = query + insert
# run the query
if i > 0 and i % batch_size == 0:
Expand Down
139 changes: 127 additions & 12 deletions tests/unit/test_neptune_parsing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import logging
from unittest.mock import MagicMock

import pytest # type: ignore
from gremlin_python.process.traversal import T
from gremlin_python.structure.graph import Edge, Path, Property, Vertex, VertexProperty

import awswrangler as wr
import awswrangler.pandas as pd
from awswrangler import exceptions
from awswrangler.neptune._client import NeptuneClient

logging.getLogger("awswrangler").setLevel(logging.DEBUG)

Expand All @@ -29,7 +32,7 @@ def test_parse_gremlin_vertex_elements(gremlin_parser):
assert df.shape == (1, 3)
assert row["id"] == "foo"
assert row["label"] == "vertex"
assert row["properties"] is None
assert not row["properties"] # gremlinpython <3.8 returns None, >=3.8 returns []

# parse multiple vertex elements
v1 = Vertex("bar")
Expand All @@ -40,7 +43,7 @@ def test_parse_gremlin_vertex_elements(gremlin_parser):
assert df.shape == (2, 3)
assert row["id"] == "bar"
assert row["label"] == "vertex"
assert row["properties"] is None
assert not row["properties"] # gremlinpython <3.8 returns None, >=3.8 returns []


# parse Edge elements
Expand All @@ -56,7 +59,7 @@ def test_parse_gremlin_edge_elements(gremlin_parser):
assert row["outV"] == "out1"
assert row["label"] == "label"
assert row["inV"] == "in1"
assert row["properties"] is None
assert not row["properties"] # gremlinpython <3.8 returns None, >=3.8 returns []

# parse multiple edge elements
v1 = Edge("bar", "out1", "label", "in2")
Expand All @@ -69,7 +72,7 @@ def test_parse_gremlin_edge_elements(gremlin_parser):
assert row["outV"] == "out1"
assert row["label"] == "label"
assert row["inV"] == "in2"
assert row["properties"] is None
assert not row["properties"] # gremlinpython <3.8 returns None, >=3.8 returns []


# parse Property elements
Expand All @@ -86,7 +89,7 @@ def test_parse_gremlin_property_elements(gremlin_parser):
assert row["value"] == "bar"
assert row["key"] == "name"
assert row["vertex"] == "v1"
assert row["properties"] is None
assert not row["properties"] # gremlinpython <3.8 returns None, >=3.8 returns []

v = Property("foo", "name", "bar")
input = [v]
Expand All @@ -100,6 +103,11 @@ def test_parse_gremlin_property_elements(gremlin_parser):


# parse Path elements
def _normalize_properties(d: dict) -> dict:
# gremlinpython <3.8 returns properties=None, >=3.8 returns properties=[].
return {k: (None if k == "properties" and not v else v) for k, v in d.items()}


def test_parse_gremlin_path_elements(gremlin_parser):
# parse path with elements
v = Vertex("foo")
Expand All @@ -110,9 +118,15 @@ def test_parse_gremlin_path_elements(gremlin_parser):
df = pd.DataFrame.from_records(out)
row = df.iloc[0]
assert df.shape == (1, 3)
assert row[0] == {"id": "foo", "label": "vertex", "properties": None}
assert row[1] == {"id": "e1", "label": "label", "outV": "foo", "inV": "bar", "properties": None}
assert row[2] == {"id": "bar", "label": "vertex", "properties": None}
assert _normalize_properties(row[0]) == {"id": "foo", "label": "vertex", "properties": None}
assert _normalize_properties(row[1]) == {
"id": "e1",
"label": "label",
"outV": "foo",
"inV": "bar",
"properties": None,
}
assert _normalize_properties(row[2]) == {"id": "bar", "label": "vertex", "properties": None}

# parse path with multiple elements
e2 = Edge("bar", "out1", "label", "in2")
Expand All @@ -122,9 +136,15 @@ def test_parse_gremlin_path_elements(gremlin_parser):
df = pd.DataFrame.from_records(out)
row = df.iloc[1]
assert df.shape == (2, 3)
assert row[0] == {"id": "bar", "label": "vertex", "properties": None}
assert row[1] == {"id": "bar", "label": "label", "outV": "out1", "inV": "in2", "properties": None}
assert row[2] == {"id": "in2", "label": "vertex", "properties": None}
assert _normalize_properties(row[0]) == {"id": "bar", "label": "vertex", "properties": None}
assert _normalize_properties(row[1]) == {
"id": "bar",
"label": "label",
"outV": "out1",
"inV": "in2",
"properties": None,
}
assert _normalize_properties(row[2]) == {"id": "in2", "label": "vertex", "properties": None}

# parse path with maps
p = Path(
Expand Down Expand Up @@ -152,7 +172,13 @@ def test_parse_gremlin_path_elements(gremlin_parser):
assert df.shape == (1, 3)
assert row[0]["name"] == "foo"
assert row[0]["age"] == 29
assert row[1] == {"id": "bar", "label": "label", "outV": "out1", "inV": "in2", "properties": None}
assert _normalize_properties(row[1]) == {
"id": "bar",
"label": "label",
"outV": "out1",
"inV": "in2",
"properties": None,
}
assert row[2]["name"] == "bar"
assert row[2]["age"] == 40

Expand Down Expand Up @@ -216,3 +242,92 @@ def test_parse_gremlin_subgraph(gremlin_parser):
assert df.shape == (1, 2)
assert row["@type"] == "tinker:graph"
assert row["@value"] == {"vertices": ["v[45]", "v[9]"], "edges": ["e[3990][9-route->45]"]}


# to_rdf_graph IRIREF validation: caller-supplied DataFrame cells must conform to
# the SPARQL IRIREF grammar so they cannot close the <...> token and inject
# arbitrary SPARQL UPDATE syntax (DELETE / DROP / LOAD / ...).


def _rdf_triples_df() -> pd.DataFrame:
return pd.DataFrame(
{
"s": ["http://example.org/alice", "http://example.org/bob"],
"p": ["http://xmlns.com/foaf/0.1/name", "http://xmlns.com/foaf/0.1/name"],
"o": ["http://example.org/AliceName", "http://example.org/BobName"],
}
)


def _mock_neptune_client() -> MagicMock:
    """Return a ``NeptuneClient``-spec'd mock whose ``write_sparql`` returns ``True``."""
    stub = MagicMock(spec=NeptuneClient)
    # Same effect as assigning stub.write_sparql.return_value directly.
    stub.configure_mock(**{"write_sparql.return_value": True})
    return stub


def test_to_rdf_graph_accepts_well_formed_iris():
    """Well-formed IRIs pass validation and are sent in exactly one SPARQL write."""
    neptune = _mock_neptune_client()
    triples = _rdf_triples_df()

    outcome = wr.neptune.to_rdf_graph(neptune, triples)

    assert outcome is True
    neptune.write_sparql.assert_called_once()
    # The first positional argument of the write call is the generated query text.
    sent_query = neptune.write_sparql.call_args.args[0]
    assert "<http://example.org/alice>" in sent_query
    assert "<http://example.org/bob>" in sent_query


@pytest.mark.parametrize(
    "malicious_cell, column",
    [
        # Bug-bounty PoC payload: closes IRI, runs DELETE WHERE, reopens INSERT.
        (
            "> . }; DELETE WHERE { ?s ?p ?o }; "
            "INSERT DATA { <http://evil.com/x> <http://evil.com/y> <http://evil.com/z",
            "o",
        ),
        # DROP ALL via the subject slot.
        ("http://x.com/s> <http://x.com/p> <http://x.com/o> . }; DROP ALL ; INSERT DATA { <a", "s"),
        # LOAD via the predicate slot.
        (
            "http://x.com/p> <http://x.com/o> . }; LOAD <http://evil.com/payload.ttl> ; "
            "INSERT DATA { <http://x.com/s> <http://x.com/p2",
            "p",
        ),
        # Whitespace alone is enough to break the IRIREF token.
        ("http://example.org/ has space", "o"),
        ("http://example.org/a\n<http://x>", "o"),
        ("http://example.org/<inner>", "s"),
    ],
)
def test_to_rdf_graph_rejects_malicious_cells(malicious_cell, column):
    """A cell that breaks the IRIREF grammar raises before anything is written."""
    neptune = _mock_neptune_client()
    triples = _rdf_triples_df()
    # Poison the first row so validation fails on the very first iteration.
    triples.loc[0, column] = malicious_cell

    with pytest.raises(exceptions.InvalidArgumentValue, match="not a valid IRI"):
        wr.neptune.to_rdf_graph(neptune, triples)

    # Validation must run before any network call.
    neptune.write_sparql.assert_not_called()


def test_to_rdf_graph_rejects_malicious_graph_column_for_quads():
    """An injection payload in the graph column ``g`` is rejected and named in the error."""
    neptune = _mock_neptune_client()
    quads = _rdf_triples_df()
    # Adding the "g" column gives every row a graph IRI; row 0 is then poisoned.
    quads["g"] = ["http://example.org/g1", "http://example.org/g2"]
    quads.loc[0, "g"] = "http://x> {} }; DROP ALL ; INSERT DATA { GRAPH <http://x> { <a"

    with pytest.raises(exceptions.InvalidArgumentValue, match="'g'"):
        wr.neptune.to_rdf_graph(neptune, quads)

    neptune.write_sparql.assert_not_called()


def test_to_rdf_graph_error_identifies_row_and_column():
    """The validation error message names the offending column and row index."""
    neptune = _mock_neptune_client()
    triples = _rdf_triples_df()
    # Only the second row (index 1) carries an invalid object value.
    triples.loc[1, "o"] = "http://example.org/bad value"

    with pytest.raises(exceptions.InvalidArgumentValue) as raised:
        wr.neptune.to_rdf_graph(neptune, triples)

    error_text = str(raised.value)
    assert "'o'" in error_text
    assert "row index 1" in error_text
Loading