From 7a05c3221c5268b5d924729c8edd258197a0ff50 Mon Sep 17 00:00:00 2001 From: Aiden Grossman Date: Sat, 26 Apr 2025 18:19:38 +0000 Subject: [PATCH] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20change?= =?UTF-8?q?s=20to=20main=20this=20commit=20is=20based=20on?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Created using spr 1.3.4 [skip ci] --- gematria/granite/python/gnn_model_base.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/gematria/granite/python/gnn_model_base.py b/gematria/granite/python/gnn_model_base.py index dd35a5a1..f35b32cb 100644 --- a/gematria/granite/python/gnn_model_base.py +++ b/gematria/granite/python/gnn_model_base.py @@ -30,25 +30,6 @@ import tf_keras -def _add_batch_dimension(shape: Sequence[int]) -> Sequence[Optional[int]]: - """Adds a batch dimension as the first dimension to a given shape. - - Args: - shape: The shape to which the dimension is added. The size in all dimensions - must be at least one. Empty shape (i.e. the shape of scalar values) is - allowed and it will produce a 1D tensor. - - Returns: - The shape with the batch dimension. - - Raises: - ValueError: When the input shape is not valid. - """ - if any(size <= 0 for size in shape): - raise ValueError('The shape may contain only positive numbers') - return (None, *shape) - - @dataclasses.dataclass(frozen=True) class GraphNetworkLayer: """Specifies one segment of the pipeline of the graph network.