Skip to content

Commit 34119c7

Browse files
[π˜€π—½π—Ώ] changes introduced through rebase
Created using spr 1.3.4 [skip ci]
1 parent 4353110 commit 34119c7

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

β€Žgematria/model/python/model_base.pyβ€Ž

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1306,6 +1306,9 @@ def compute_loss_tensor(self, schedule: FeedDict):
13061306
)
13071307
)
13081308

1309+
def _get_trainable_variables(self):
1310+
return self.trainable_variables
1311+
13091312
def train_batch(
13101313
self,
13111314
schedule: FeedDict,
@@ -1341,10 +1344,11 @@ def train_batch(
13411344
]
13421345
)
13431346

1347+
trainable_variables = self._get_trainable_variables()
13441348
variables = (
13451349
[variable.deref() for variable in variables]
13461350
if variables
1347-
else self.trainable_variables
1351+
else trainable_variables
13481352
)
13491353

13501354
grads = tape.gradient(loss_tensor, variables)

0 commit comments

Comments
Β (0)