Skip to content

Commit 268e75f

Browse files
rebase
Created using spr 1.3.4
2 parents d74f8d5 + c08fda8 commit 268e75f

2 files changed

Lines changed: 11 additions & 1 deletion

File tree

gematria/granite/python/gnn_model_base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,12 @@ def initialize(self):
310310
tf_keras.layers.LayerNormalization(name=globals_layer_norm_name)
311311
)
312312

313+
def _get_trainable_variables(self):
314+
trainable_variables = list(super()._get_trainable_variables())
315+
for layer in self._graph_network:
316+
trainable_variables.extend(layer.module.trainable_variables)
317+
return trainable_variables
318+
313319
# @Override
314320
def _forward(self, feed_dict):
315321
graph_tuple_outputs = self._execute_graph_network(feed_dict)

gematria/model/python/model_base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1312,6 +1312,9 @@ def compute_loss_tensor(self, schedule: FeedDict):
13121312
)
13131313
)
13141314

1315+
def _get_trainable_variables(self):
1316+
return self.trainable_variables
1317+
13151318
def train_batch(
13161319
self,
13171320
schedule: FeedDict,
@@ -1347,10 +1350,11 @@ def train_batch(
13471350
]
13481351
)
13491352

1353+
trainable_variables = self._get_trainable_variables()
13501354
variables = (
13511355
[variable.deref() for variable in variables]
13521356
if variables
1353-
else self.trainable_variables
1357+
else trainable_variables
13541358
)
13551359

13561360
grads = tape.gradient(loss_tensor, variables)

0 commit comments

Comments
 (0)