Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions tensormap-backend/app/services/model_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,39 @@ def model_generation(model_params: dict) -> dict:
len(model_params["edges"]),
)

nodes = model_params.get("nodes", [])
edges = model_params.get("edges", [])
if not nodes:
raise ValueError("Graph must include at least one node")

node_ids = [node["id"] for node in nodes]
if len(set(node_ids)) != len(node_ids):
raise ValueError("Duplicate node IDs are not allowed")

input_ids = [node["id"] for node in nodes if node["type"] == "custominput"]
if not input_ids:
raise ValueError("Graph must include at least one input node")

nodes_by_id = {node["id"]: node for node in nodes}

# Build adjacency maps
source_to_targets = defaultdict(list)
target_to_sources = defaultdict(list)
for edge in model_params["edges"]:
for edge in edges:
if edge["source"] not in nodes_by_id or edge["target"] not in nodes_by_id:
raise ValueError(f"Edge references unknown node(s): source={edge['source']}, target={edge['target']}")
source_to_targets[edge["source"]].append(edge["target"])
target_to_sources[edge["target"]].append(edge["source"])

nodes_by_id = {node["id"]: node for node in model_params["nodes"]}
for input_id in input_ids:
if target_to_sources.get(input_id):
raise ValueError(f"Input node '{input_id}' cannot have incoming edges")
if not source_to_targets.get(input_id):
raise ValueError(f"Input node '{input_id}' must connect to at least one layer")

for node in nodes:
if node["type"] != "custominput" and not target_to_sources.get(node["id"]):
raise ValueError(f"Node '{node['id']}' has no incoming edges")

# BFS from input nodes to build Keras layers in topological order
keras_tensors = {}
Expand Down
40 changes: 40 additions & 0 deletions tensormap-backend/tests/test_model_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,46 @@ def test_single_input_no_edges(self):
with pytest.raises(ValueError):
model_generation(params)

def test_duplicate_node_ids_raise_value_error(self):
params = {
"nodes": [_input_node("dup", [4]), _dense_node("dup", 8, "relu")],
"edges": [],
}
with pytest.raises(ValueError, match="Duplicate node IDs"):
model_generation(params)

def test_edge_with_unknown_node_raises_value_error(self):
params = {
"nodes": [_input_node("in", [4]), _dense_node("out", 1, "linear")],
"edges": [_edge("in", "missing")],
}
with pytest.raises(ValueError, match="unknown node"):
model_generation(params)

def test_non_input_node_without_incoming_edge_raises_value_error(self):
params = {
"nodes": [_input_node("in", [4]), _dense_node("orphan", 8, "relu")],
"edges": [],
}
with pytest.raises(ValueError, match="has no incoming edges"):
model_generation(params)

def test_cycle_detection_raises_value_error(self):
params = {
"nodes": [_input_node("in", [4]), _dense_node("a", 8, "relu"), _dense_node("b", 4, "relu")],
"edges": [_edge("in", "a"), _edge("a", "b"), _edge("b", "a")],
}
with pytest.raises(ValueError, match="disconnected or cyclic"):
model_generation(params)

def test_input_with_incoming_edge_raises_value_error(self):
params = {
"nodes": [_input_node("in1", [4]), _input_node("in2", [4]), _dense_node("h1", 8, "relu")],
"edges": [_edge("in1", "h1"), _edge("h1", "in2")],
}
with pytest.raises(ValueError, match="cannot have incoming edges"):
model_generation(params)

def test_unknown_layer_type_raises_value_error(self):
"""An unsupported node type in the graph must propagate a ValueError."""
params = {
Expand Down
2 changes: 1 addition & 1 deletion tensormap-backend/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading