We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 3ef802f commit d3e38a0Copy full SHA for d3e38a0
HRM/hrm.py
@@ -1,4 +1,5 @@
1
from __future__ import annotations
2
+from contextlib import nullcontext
3
4
import torch
5
from torch import nn, Tensor, tensor, is_tensor
@@ -90,7 +91,8 @@ def forward(
90
91
hiddens: tuple[Tensor, ...] | None = None,
92
*,
93
labels = None,
- detach_hiddens = True
94
+ detach_hiddens = True,
95
+ one_step_grad = True
96
):
97
98
if detach_hiddens:
@@ -102,7 +104,9 @@ def forward(
102
104
103
105
# network as they proposed - following figure 4
106
- with torch.no_grad():
107
+ context = torch.no_grad if one_step_grad else nullcontext
108
+
109
+ with context():
110
for index in range(self.reasoning_steps * self.lowest_steps_per_reasoning_step - 1):
111
iteration = index + 1
112
0 commit comments