We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 7fd3b4e + a1eb05a commit 954c41dCopy full SHA for 954c41d
1 file changed
gematria/model/python/model_base.py
@@ -1312,6 +1312,9 @@ def compute_loss_tensor(self, schedule: FeedDict):
1312
)
1313
1314
1315
+ def _get_trainable_variables(self):
1316
+ return self.trainable_variables
1317
+
1318
def train_batch(
1319
self,
1320
schedule: FeedDict,
@@ -1347,10 +1350,11 @@ def train_batch(
1347
1350
]
1348
1351
1349
1352
1353
+ trainable_variables = self._get_trainable_variables()
1354
variables = (
1355
[variable.deref() for variable in variables]
1356
if variables
- else self.trainable_variables
1357
+ else trainable_variables
1358
1359
1360
grads = tape.gradient(loss_tensor, variables)
0 commit comments