27 changes: 27 additions & 0 deletions tensormap-backend/app/services/model_generation.py
@@ -114,6 +114,33 @@ def _build_layer(node: dict, input_tensor):
name=name,
)(input_tensor)

elif node_type == "customlstm":
# LSTM requires 3D input: (batch_size, timesteps, features)
if len(input_tensor.shape) != 3:
raise ValueError(
f"LSTM requires 3D input (batch, timesteps, features), "
f"got shape {input_tensor.shape}. Insert a Reshape layer before LSTM."
)

# return_sequences is stored as a string ("true"/"false") from frontend Select.
# Activation params are also strings; convert "none" to "linear" for TensorFlow.
try:
units = int(params.get("units", 0) or 0)
if units <= 0:
raise ValueError("LSTM units must be a positive integer")
except (ValueError, TypeError) as exc:
raise ValueError(f"Invalid LSTM units parameter: {exc}") from exc

activation = params.get("activation", "tanh")
recurrent_activation = params.get("recurrentActivation", "sigmoid")
return tf.keras.layers.LSTM(
units=units,
activation="linear" if activation == "none" else activation,
recurrent_activation="linear" if recurrent_activation == "none" else recurrent_activation,
return_sequences=params.get("returnSequences") in (True, "true", 1),
name=name,
)(input_tensor)

elif node_type == "customdropout":
rate = float(params.get("rate", 0.5))
if not 0.0 <= rate < 1.0:
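For anyone trying this branch locally: a minimal sketch (assuming TensorFlow 2.x) of what the new customlstm branch builds, including the Reshape workaround that the 3D-input error message recommends. Shapes and names here are illustrative, not captured from the app.

import tensorflow as tf

# 3D input (batch, timesteps=10, features=8): LSTM applies directly.
inputs = tf.keras.Input(shape=(10, 8))
x = tf.keras.layers.LSTM(units=64, return_sequences=False)(inputs)
print(x.shape)  # (None, 64) -- time dimension collapsed

# 2D input (batch, 80): insert a Reshape first, as the error message suggests.
flat = tf.keras.Input(shape=(80,))
seq = tf.keras.layers.Reshape((10, 8))(flat)  # -> (None, 10, 8)
y = tf.keras.layers.LSTM(units=64, return_sequences=True)(seq)
print(y.shape)  # (None, 10, 64) -- time dimension preserved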
131 changes: 131 additions & 0 deletions tensormap-backend/tests/test_lstm_layer.py
@@ -0,0 +1,131 @@
"""Tests for LSTM layer support in model_generation._build_layer."""

import pytest
import tensorflow as tf

from app.services.model_generation import _build_layer


def _make_node(units, return_sequences="false"):
return {
"id": "lstm-test",
"type": "customlstm",
"data": {"params": {"units": units, "returnSequences": return_sequences}},
}


def test_build_lstm_layer_basic():
input_tensor = tf.keras.Input(shape=(10, 8))
node = _make_node("64")
output = _build_layer(node, input_tensor)
assert output.shape[-1] == 64


def test_build_lstm_return_sequences_false():
input_tensor = tf.keras.Input(shape=(10, 8))
node = _make_node("32", "false")
output = _build_layer(node, input_tensor)
assert len(output.shape) == 2


def test_build_lstm_return_sequences_true():
input_tensor = tf.keras.Input(shape=(10, 8))
node = _make_node("32", "true")
output = _build_layer(node, input_tensor)
assert len(output.shape) == 3


def test_build_lstm_invalid_units_empty():
input_tensor = tf.keras.Input(shape=(10, 8))
node = _make_node("")
with pytest.raises(ValueError, match="Invalid LSTM units"):
_build_layer(node, input_tensor)


def test_build_lstm_invalid_units_zero():
input_tensor = tf.keras.Input(shape=(10, 8))
node = _make_node(0)
with pytest.raises(ValueError, match="LSTM units must be a positive integer"):
_build_layer(node, input_tensor)


def test_build_lstm_return_sequences_true_shape():
input_tensor = tf.keras.Input(shape=(10, 8))
node = _make_node("32", "true")
output = _build_layer(node, input_tensor)
# return_sequences=True preserves the time dimension
assert len(output.shape) == 3
assert output.shape[-1] == 32


def test_build_lstm_invalid_units_negative():
input_tensor = tf.keras.Input(shape=(10, 8))
node = _make_node(-10)
with pytest.raises(ValueError, match="LSTM units must be a positive integer"):
_build_layer(node, input_tensor)


def test_build_lstm_invalid_units_nan():
input_tensor = tf.keras.Input(shape=(10, 8))
node = _make_node("NaN")
with pytest.raises(ValueError):
_build_layer(node, input_tensor)


def test_build_lstm_return_sequences_false_shape():
"""Verify 2D output when return_sequences=False."""
input_tensor = tf.keras.Input(shape=(10, 8))
node = _make_node("64", "false")
output = _build_layer(node, input_tensor)
assert len(output.shape) == 2 # (batch, units)
assert output.shape[-1] == 64


def test_build_lstm_with_activation():
"""Test LSTM with custom activation function."""
input_tensor = tf.keras.Input(shape=(10, 8))
node = {
"id": "lstm-test",
"type": "customlstm",
"data": {"params": {"units": "32", "returnSequences": "false", "activation": "relu"}},
}
output = _build_layer(node, input_tensor)
assert output.shape[-1] == 32


def test_build_lstm_invalid_input_shape_2d():
"""LSTM requires 3D input, should reject 2D."""
input_tensor = tf.keras.Input(shape=(64,)) # 2D input!
node = _make_node("32")
with pytest.raises(ValueError, match="3D input"):
_build_layer(node, input_tensor)


def test_build_lstm_invalid_input_shape_4d():
"""LSTM requires 3D input, should reject 4D."""
input_tensor = tf.keras.Input(shape=(10, 8, 3)) # 4D input!
node = _make_node("32")
with pytest.raises(ValueError, match="3D input"):
_build_layer(node, input_tensor)


def test_build_lstm_with_recurrent_activation():
"""Test LSTM with custom recurrent activation."""
input_tensor = tf.keras.Input(shape=(10, 8))
node = {
"id": "lstm-test",
"type": "customlstm",
"data": {
"params": {"units": "32", "returnSequences": "false", "activation": "tanh", "recurrentActivation": "relu"}
},
}
output = _build_layer(node, input_tensor)
assert output.shape[-1] == 32


def test_build_lstm_large_units():
"""Test LSTM with large unit count."""
input_tensor = tf.keras.Input(shape=(10, 8))
node = _make_node("1024")
output = _build_layer(node, input_tensor)
assert output.shape[-1] == 1024
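Side note: the three invalid-units tests above share one pattern; if more cases accumulate, they could be folded into a single parametrized test. A sketch, reusing the _make_node helper and _build_layer import from this file:

import pytest
import tensorflow as tf

# Possible consolidation of the invalid-units cases above.
@pytest.mark.parametrize("units", ["", 0, -10, "NaN"])
def test_build_lstm_rejects_invalid_units(units):
    input_tensor = tf.keras.Input(shape=(10, 8))
    node = _make_node(units)
    with pytest.raises(ValueError):
        _build_layer(node, input_tensor)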
@@ -30,6 +30,7 @@ import DenseNode from "./CustomNodes/DenseNode/DenseNode";
import FlattenNode from "./CustomNodes/FlattenNode/FlattenNode";
import ConvNode from "./CustomNodes/ConvNode/ConvNode";
import DropoutNode from "./CustomNodes/DropoutNode/DropoutNode";
import LSTMNode from "./CustomNodes/LSTMNode/LSTMNode";
import MaxPoolingNode from "./CustomNodes/MaxPoolingNode/MaxPoolingNode";
import Sidebar from "./Sidebar";
import NodePropertiesPanel from "./NodePropertiesPanel";
@@ -56,6 +57,7 @@ const nodeTypes = {
customdropout: DropoutNode,
custommaxpool: MaxPoolingNode,
customglobalavgpool: GlobalAvgPoolNode,
customlstm: LSTMNode,
};

function Canvas() {
@@ -539,6 +541,12 @@ function Canvas() {
customdropout: { rate: "" },
custommaxpool: { pool_size: "", stride: "", padding: "valid" },
customglobalavgpool: {},
customlstm: {
units: "",
returnSequences: "false",
activation: "tanh",
recurrentActivation: "sigmoid",
}, // Stored as strings to match Select component values
};

const newNode = {
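For reference, the defaults above mean a freshly configured LSTM node reaches the backend as a payload like this (a hand-written example; the id is hypothetical). Every param is a string, which is why _build_layer coerces units with int() and compares returnSequences against "true":

node = {
    "id": "node-3",  # hypothetical id
    "type": "customlstm",
    "data": {
        "params": {
            "units": "64",               # coerced with int() on the backend
            "returnSequences": "false",  # a string, not a boolean
            "activation": "tanh",
            "recurrentActivation": "sigmoid",
        }
    },
}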
@@ -0,0 +1,44 @@
import PropTypes from "prop-types";
import { Handle, Position } from "reactflow";

function LSTMNode({ data, id }) {
const { units, returnSequences, activation, recurrentActivation } = data.params;
const parsedUnits = Number(units);
const configured = String(units).trim() !== "" && parsedUnits > 0;
const activationLabel = activation || "tanh";
const recurrentLabel = recurrentActivation || "sigmoid";

return (
<div className="w-56 rounded-lg border bg-white shadow-sm">
<Handle type="target" position={Position.Left} isConnectable id={`${id}_in`} />
<div className="rounded-t-lg bg-node-lstm px-3 py-1.5 text-xs font-bold text-white">LSTM</div>
<div className="px-3 py-2 text-xs text-muted-foreground">
{configured
? `Units: ${units}${returnSequences === "true" || returnSequences === true ? " • seq" : ""}`
: "Not configured"}
</div>
{configured && (
<div className="border-t px-3 py-1.5 text-xs text-gray-500">
<div>
{activationLabel} / {recurrentLabel}
</div>
</div>
)}
<Handle type="source" position={Position.Right} isConnectable id={`${id}_out`} />
</div>
);
}

LSTMNode.propTypes = {
data: PropTypes.shape({
params: PropTypes.shape({
units: PropTypes.oneOfType([PropTypes.string, PropTypes.number]),
returnSequences: PropTypes.string,
activation: PropTypes.string,
recurrentActivation: PropTypes.string,
}).isRequired,
}).isRequired,
id: PropTypes.string.isRequired,
};

export default LSTMNode;
@@ -0,0 +1,70 @@
import { render, screen } from "@testing-library/react";
import { describe, it, expect, vi } from "vitest";
import LSTMNode from "./LSTMNode";

vi.mock("reactflow", () => ({
Handle: (props) => <div data-testid={`handle-${props.type}-${props.position}`} {...props} />,
Position: { Left: "left", Right: "right", Top: "top", Bottom: "bottom" },
}));

describe("LSTMNode", () => {
const defaultProps = {
id: "test-node-lstm",
data: { params: { units: "", returnSequences: "false" } },
};

it("renders the title correctly", () => {
render(<LSTMNode {...defaultProps} />);
expect(screen.getByText("LSTM")).toBeInTheDocument();
});

it("shows Not configured when units is empty", () => {
render(<LSTMNode {...defaultProps} />);
expect(screen.getByText("Not configured")).toBeInTheDocument();
});

it("shows units when configured", () => {
const props = { ...defaultProps, data: { params: { units: 64, returnSequences: "false" } } };
render(<LSTMNode {...props} />);
expect(screen.getByText("Units: 64")).toBeInTheDocument();
});

it("shows seq suffix when returnSequences is true", () => {
const props = { ...defaultProps, data: { params: { units: 32, returnSequences: "true" } } };
render(<LSTMNode {...props} />);
expect(screen.getByText("Units: 32 \u00b7 seq")).toBeInTheDocument();
});

it("renders target handle on the left", () => {
render(<LSTMNode {...defaultProps} />);
expect(screen.getByTestId("handle-target-left")).toBeInTheDocument();
});

it("renders source handle on the right", () => {
render(<LSTMNode {...defaultProps} />);
expect(screen.getByTestId("handle-source-right")).toBeInTheDocument();
});

it("shows Not configured for non-numeric units", () => {
const props = { ...defaultProps, data: { params: { units: "abc", returnSequences: "false" } } };
render(<LSTMNode {...props} />);
expect(screen.getByText("Not configured")).toBeInTheDocument();
});

it("shows Not configured for zero units", () => {
const props = { ...defaultProps, data: { params: { units: 0, returnSequences: "false" } } };
render(<LSTMNode {...props} />);
expect(screen.getByText("Not configured")).toBeInTheDocument();
});

it("shows Not configured for negative units", () => {
const props = { ...defaultProps, data: { params: { units: -64, returnSequences: "false" } } };
render(<LSTMNode {...props} />);
expect(screen.getByText("Not configured")).toBeInTheDocument();
});

it("handles missing params gracefully", () => {
const props = { id: "test", data: { params: {} } };
expect(() => render(<LSTMNode {...props} />)).not.toThrow();
});
});
@@ -35,6 +35,11 @@ export const canSaveModel = (modelName, modelData) => {
if (!p.pool_size || !p.stride) {
return false;
}
} else if (node.type === "customlstm") {
const units = Number(node.data.params.units);
if (!Number.isInteger(units) || units <= 0 || units > 10000) {
return false; // Reject non-integers and values outside [1, 10000]
}
} else if (node.type === "customdropout") {
const rate = parseFloat(node.data.params.rate);
if (node.data.params.rate === "" || isNaN(rate) || rate < 0 || rate >= 1) {
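A parity note on the units gate above: the frontend should reject at least everything the backend's int() coercion rejects (the 10000 cap is frontend-only). A small Python sketch of the backend behavior from model_generation.py for typical raw values:

# How the backend coerces the units param (see the first file in this diff).
for raw in ["64", "", "64.5", "NaN", 0]:
    try:
        units = int(raw or 0)  # "" and None fall back to 0
        valid = units > 0
    except (ValueError, TypeError):
        valid = False
    print(repr(raw), "->", valid)
# '64' -> True, '' -> False, '64.5' -> False, 'NaN' -> False, 0 -> False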