Skip to content

Commit a1eb05a

Browse files
[π˜€π—½π—Ώ] changes introduced through rebase
Created using spr 1.3.4 [skip ci]
1 parent d78b9c3 commit a1eb05a

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

β€Žgematria/model/python/model_base.pyβ€Ž

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1312,6 +1312,9 @@ def compute_loss_tensor(self, schedule: FeedDict):
13121312
)
13131313
)
13141314

1315+
def _get_trainable_variables(self):
1316+
return self.trainable_variables
1317+
13151318
def train_batch(
13161319
self,
13171320
schedule: FeedDict,
@@ -1347,10 +1350,11 @@ def train_batch(
13471350
]
13481351
)
13491352

1353+
trainable_variables = self._get_trainable_variables()
13501354
variables = (
13511355
[variable.deref() for variable in variables]
13521356
if variables
1353-
else self.trainable_variables
1357+
else trainable_variables
13541358
)
13551359

13561360
grads = tape.gradient(loss_tensor, variables)

0 commit comments

Comments
Β (0)