Skip to content

Commit 954c41d

Browse files
rebase
Created using spr 1.3.4
2 parents 7fd3b4e + a1eb05a commit 954c41d

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

gematria/model/python/model_base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1312,6 +1312,9 @@ def compute_loss_tensor(self, schedule: FeedDict):
13121312
)
13131313
)
13141314

1315+
def _get_trainable_variables(self):
1316+
return self.trainable_variables
1317+
13151318
def train_batch(
13161319
self,
13171320
schedule: FeedDict,
@@ -1347,10 +1350,11 @@ def train_batch(
13471350
]
13481351
)
13491352

1353+
trainable_variables = self._get_trainable_variables()
13501354
variables = (
13511355
[variable.deref() for variable in variables]
13521356
if variables
1353-
else self.trainable_variables
1357+
else trainable_variables
13541358
)
13551359

13561360
grads = tape.gradient(loss_tensor, variables)

0 commit comments

Comments
 (0)