Skip to content

Commit f95f2c8

Browse files
rebase
Created using spr 1.3.4
2 parents 70176ae + c7dfb19 commit f95f2c8

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

gematria/model/python/model_base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1306,6 +1306,9 @@ def compute_loss_tensor(self, schedule: FeedDict):
13061306
)
13071307
)
13081308

1309+
def _get_trainable_variables(self):
1310+
return self.trainable_variables
1311+
13091312
def train_batch(
13101313
self,
13111314
schedule: FeedDict,
@@ -1341,10 +1344,11 @@ def train_batch(
13411344
]
13421345
)
13431346

1347+
trainable_variables = self._get_trainable_variables()
13441348
variables = (
13451349
[variable.deref() for variable in variables]
13461350
if variables
1347-
else self.trainable_variables
1351+
else trainable_variables
13481352
)
13491353

13501354
grads = tape.gradient(loss_tensor, variables)

0 commit comments

Comments
 (0)