We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 70176ae + c7dfb19 commit f95f2c8Copy full SHA for f95f2c8
1 file changed
gematria/model/python/model_base.py
@@ -1306,6 +1306,9 @@ def compute_loss_tensor(self, schedule: FeedDict):
1306
)
1307
1308
1309
+ def _get_trainable_variables(self):
1310
+ return self.trainable_variables
1311
+
1312
def train_batch(
1313
self,
1314
schedule: FeedDict,
@@ -1341,10 +1344,11 @@ def train_batch(
1341
1344
]
1342
1345
1343
1346
1347
+ trainable_variables = self._get_trainable_variables()
1348
variables = (
1349
[variable.deref() for variable in variables]
1350
if variables
- else self.trainable_variables
1351
+ else trainable_variables
1352
1353
1354
grads = tape.gradient(loss_tensor, variables)
0 commit comments