Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions tensormap-backend/app/services/model_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ def _build_layer(node: dict, input_tensor):
elif node_type == "customflatten":
return tf.keras.layers.Flatten(name=name)(input_tensor)

elif node_type == "customdropout":
return tf.keras.layers.Dropout(
rate=float(params.get("rate", 0.5)),
name=name,
)(input_tensor)

elif node_type == "custommaxpool":
return tf.keras.layers.MaxPooling2D(
pool_size=int(params.get("pool_size", 2)),
Expand All @@ -103,6 +109,12 @@ def _build_layer(node: dict, input_tensor):
elif node_type == "customglobalavgpool":
return tf.keras.layers.GlobalAveragePooling2D(name=name)(input_tensor)

elif node_type == "custombatchnorm":
return tf.keras.layers.BatchNormalization(
momentum=float(params.get("momentum", 0.99)),
epsilon=float(params.get("epsilon", 0.001)),
name=name,
)(input_tensor)
elif node_type == "customconv":
activation = params["activation"]
return tf.keras.layers.Conv2D(
Expand Down
41 changes: 32 additions & 9 deletions tensormap-backend/tests/test_model_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,19 +304,42 @@ def test_maxpool_default_params(self):

def test_maxpool_in_model(self):
"""input → conv → maxpool → flatten → dense end-to-end."""


def _batchnorm_node(node_id: str, momentum: float = 0.99, epsilon: float = 0.001) -> dict:
return {
"id": node_id,
"type": "custombatchnorm",
"data": {"params": {"momentum": momentum, "epsilon": epsilon}},
}


class TestBatchNormLayer:
"""Unit and integration tests for the BatchNormalization layer."""

def test_batchnorm_output_shape(self):
input_t = tf.keras.Input(shape=(28, 28, 16), name="inp")
node = _batchnorm_node("bn1")
output = _build_layer(node, input_t)
assert output.shape == (None, 28, 28, 16)

def test_batchnorm_default_params(self):
input_t = tf.keras.Input(shape=(10,), name="inp")
node = _batchnorm_node("bn1")
output = _build_layer(node, input_t)
assert output is not None

def test_batchnorm_in_model(self):
"""input → batchnorm → dense end-to-end."""
params = {
"nodes": [
_input_node("x", [28, 28, 1]),
_conv_node("c1", filters=16, kernel=(3, 3), stride=(1, 1), padding="same"),
_maxpool_node("mp1"),
_flatten_node("flat"),
_dense_node("out", 10, "softmax"),
_input_node("x", [16]),
_batchnorm_node("bn1"),
_dense_node("out", 4, "softmax"),
],
"edges": [
_edge("x", "c1"),
_edge("c1", "mp1"),
_edge("mp1", "flat"),
_edge("flat", "out"),
_edge("x", "bn1"),
_edge("bn1", "out"),
],
}
result = model_generation(params)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import FlattenNode from "./CustomNodes/FlattenNode/FlattenNode";
import ConvNode from "./CustomNodes/ConvNode/ConvNode";
import DropoutNode from "./CustomNodes/DropoutNode/DropoutNode";
import MaxPoolingNode from "./CustomNodes/MaxPoolingNode/MaxPoolingNode";
import BatchNormalizationNode from "./CustomNodes/BatchNormalizationNode/BatchNormalizationNode";
import Sidebar from "./Sidebar";
import NodePropertiesPanel from "./NodePropertiesPanel";
import { canSaveModel, generateModelJSON } from "./Helpers";
Expand All @@ -56,6 +57,7 @@ const nodeTypes = {
customdropout: DropoutNode,
custommaxpool: MaxPoolingNode,
customglobalavgpool: GlobalAvgPoolNode,
custombatchnorm: BatchNormalizationNode,
};

function Canvas() {
Expand Down Expand Up @@ -539,6 +541,7 @@ function Canvas() {
customdropout: { rate: "" },
custommaxpool: { pool_size: "", stride: "", padding: "valid" },
customglobalavgpool: {},
custombatchnorm: { momentum: "0.99", epsilon: "0.001" },
};

const newNode = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import PropTypes from "prop-types";
import { Handle, Position } from "reactflow";

function BatchNormalizationNode({ data, id }) {
const { momentum, epsilon } = data.params;
const isConfigured =
momentum !== "" && momentum !== undefined && epsilon !== "" && epsilon !== undefined;
return (
<div className="w-44 rounded-lg border bg-white shadow-sm">
<Handle type="target" position={Position.Left} isConnectable id={`${id}_in`} />
<div className="rounded-t-lg bg-node-batchnorm px-3 py-1.5 text-xs font-bold text-white">
BatchNorm
</div>
<div className="px-3 py-2 text-xs text-muted-foreground">
{isConfigured ? `Momentum: ${momentum} | Epsilon: ${epsilon}` : "Not configured"}
</div>
<Handle type="source" position={Position.Right} isConnectable id={`${id}_out`} />
</div>
);
}

// Runtime prop validation: momentum/epsilon may arrive either as strings
// (raw form input from the properties panel) or as numbers (saved graphs).
BatchNormalizationNode.propTypes = {
  data: PropTypes.shape({
    params: PropTypes.shape({
      momentum: PropTypes.oneOfType([PropTypes.string, PropTypes.number]),
      epsilon: PropTypes.oneOfType([PropTypes.string, PropTypes.number]),
    }).isRequired,
  }).isRequired,
  id: PropTypes.string.isRequired,
};

export default BatchNormalizationNode;
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import { render, screen } from "@testing-library/react";
import { ReactFlowProvider } from "reactflow";
import BatchNormalizationNode from "./BatchNormalizationNode";

describe("BatchNormalizationNode", () => {
it("renders not configured when params are empty", () => {
render(
<ReactFlowProvider>
<BatchNormalizationNode id="1" data={{ params: { momentum: "", epsilon: "" } }} />
</ReactFlowProvider>,
);
expect(screen.getByText("Not configured")).toBeInTheDocument();
});

it("renders params when configured", () => {
render(
<ReactFlowProvider>
<BatchNormalizationNode id="1" data={{ params: { momentum: 0.99, epsilon: 0.001 } }} />
</ReactFlowProvider>,
);
expect(screen.getByText("Momentum: 0.99 | Epsilon: 0.001")).toBeInTheDocument();
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,16 @@ export const canSaveModel = (modelName, modelData) => {
if (node.data.params.rate === "" || isNaN(rate) || rate < 0 || rate >= 1) {
return false;
}
} else if (node.type === "custombatchnorm") {
const p = node.data.params;
if (!p.momentum || !p.epsilon) {
return false;
}
}
// customflatten and customdropout have no required params to validate
}
return isGraphConnected(modelData);
};

const isGraphConnected = (graph) => {
if (!graph.nodes || graph.nodes.length === 0) return false;
const visited = new Set();
Expand All @@ -66,7 +70,6 @@ const isGraphConnected = (graph) => {
}
return visited.size === graph.nodes.length;
};

/**
* Strips visual-only properties from a ReactFlow graph snapshot using
* immutable operations (no destructive delete mutations).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,41 @@ function NodePropertiesPanel({
);
}

if (type === "custombatchnorm") {
return (
<Card className="h-fit">
<CardHeader>
<CardTitle className="text-sm">BatchNorm Layer</CardTitle>
</CardHeader>
<CardContent className="space-y-3">
<div className="space-y-1">
<Label>Momentum</Label>
<Input
type="number"
min="0"
max="1"
step="0.01"
placeholder="Momentum"
value={params.momentum}
onChange={(e) => updateParam("momentum", e.target.value)}
/>
</div>
<div className="space-y-1">
<Label>Epsilon</Label>
<Input
type="number"
min="0"
step="0.0001"
placeholder="Epsilon"
value={params.epsilon}
onChange={(e) => updateParam("epsilon", e.target.value)}
/>
</div>
</CardContent>
</Card>
);
}

return null;
}

Expand Down
10 changes: 9 additions & 1 deletion tensormap-frontend/src/components/DragAndDropCanvas/Sidebar.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ function Sidebar() {
Dropout
</div>
<div
className="cursor-grab rounded-md border border-l-4 border-l-node-conv bg-white px-3 py-2 text-xs font-medium"
className="cursor-grab rounded-md border border-l-4 border-l-node-maxpool bg-white px-3 py-2 text-xs font-medium"
onDragStart={(e) => onDragStart(e, "custommaxpool")}
draggable
>
Expand All @@ -67,6 +67,14 @@ function Sidebar() {
>
GlobalAvgPool2D
</div>

<div
className="cursor-grab rounded-md border border-l-4 border-l-node-batchnorm bg-white px-3 py-2 text-xs font-medium"
onDragStart={(e) => onDragStart(e, "custombatchnorm")}
draggable
>
BatchNorm
</div>
</CardContent>
</Card>
);
Expand Down
2 changes: 0 additions & 2 deletions tensormap-frontend/src/index.css
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,9 @@
--chart-4: 43 74% 66%;
--chart-5: 27 87% 67%;
}

* {
@apply border-border;
}

body {
@apply bg-background text-foreground;
}
Expand Down
1 change: 1 addition & 0 deletions tensormap-frontend/tailwind.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ export default {
"node-conv": { DEFAULT: "rgb(255, 128, 43)", header: "rgb(255, 128, 43)" },
"node-dropout": { DEFAULT: "rgb(220, 80, 80)", header: "rgb(180, 50, 50)" },
"node-maxpool": { DEFAULT: "rgb(34, 182, 176)", header: "rgb(20, 140, 135)" },
"node-batchnorm": { DEFAULT: "rgb(140, 90, 200)", header: "rgb(110, 60, 170)" },
},
borderRadius: {
lg: "var(--radius)",
Expand Down
Loading