Skip to content

Commit 08bdd9e

Browse files
committed
cleanup
1 parent 19441ca commit 08bdd9e

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

HRM/hrm.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ def __init__(
7979
min_reasoning_steps_epsilon_prob = 0.5, # they stochastically choose the minimum segment from 2 .. max with this probability, and 1 step the rest of the time
8080
max_reasoning_steps = 10,
8181
act_loss_weight = 1.,
82-
ignore_index = -1
82+
discount_factor = 1.,
83+
ignore_index = -1,
8384
):
8485
super().__init__()
8586

@@ -134,6 +135,8 @@ def __init__(
134135

135136
# Q values for continue / halt, used by the adaptive computation time (ACT) setup
136137

138+
self.discount_factor = discount_factor
139+
137140
self.act_loss_weight = act_loss_weight
138141

139142
self.min_reasoning_steps_epsilon_prob = min_reasoning_steps_epsilon_prob
@@ -255,15 +258,26 @@ def evaluate_pred():
255258
should_halt = q_halt > q_continue
256259

257260
if return_loss:
261+
262+
# Q_halt - binary cross entropy against whether every predicted label matches the target
263+
258264
with torch.no_grad():
259265
is_correct = (evaluate_pred().argmax(dim = -1) == labels).all(dim = -1)
260266

261-
halt_target_loss = F.binary_cross_entropy(q_halt, is_correct.float())
267+
halt_target_loss = F.binary_cross_entropy(
268+
q_halt,
269+
is_correct.float()
270+
)
262271

263272
act_losses.append(halt_target_loss)
264273

274+
# Q_continue - prev_q_continue is trained toward max(q_continue, q_halt) scaled by the discount factor
275+
265276
if exists(prev_q_continue):
266-
continue_target_loss = F.binary_cross_entropy(prev_q_continue, torch.maximum(q_continue, q_halt))
277+
continue_target_loss = F.binary_cross_entropy(
278+
prev_q_continue,
279+
torch.maximum(q_continue, q_halt) * self.discount_factor
280+
)
267281

268282
act_losses.append(continue_target_loss)
269283

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "HRM-pytorch"
3-
version = "0.0.8"
3+
version = "0.0.9"
44
description = "The proposal from a Singaporean AGI company"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments

Comments
 (0)