Skip to content

Commit d3e38a0

Browse files
committed
exploration
1 parent 3ef802f commit d3e38a0

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

HRM/hrm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
from contextlib import nullcontext
23

34
import torch
45
from torch import nn, Tensor, tensor, is_tensor
@@ -90,7 +91,8 @@ def forward(
9091
hiddens: tuple[Tensor, ...] | None = None,
9192
*,
9293
labels = None,
93-
detach_hiddens = True
94+
detach_hiddens = True,
95+
one_step_grad = True
9496
):
9597

9698
if detach_hiddens:
@@ -102,7 +104,9 @@ def forward(
102104

103105
# network as they proposed - following figure 4
104106

105-
with torch.no_grad():
107+
context = torch.no_grad if one_step_grad else nullcontext
108+
109+
with context():
106110
for index in range(self.reasoning_steps * self.lowest_steps_per_reasoning_step - 1):
107111
iteration = index + 1
108112

0 commit comments

Comments
 (0)