Skip to content

Commit 40060ff

Browse files
committed
sigmoid for q values
1 parent 64d8503 commit 40060ff

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

HRM/hrm.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -74,6 +74,7 @@ def __init__(
 74  74    relative_period: int | tuple[int, ...] = 2, # the relative period for each network evaluation call to the one just previous - in the paper, they do 2 networks with a period of 2
 75  75    min_reasoning_steps = 1,
 76  76    max_reasoning_steps = 10,
     77  + act_binary_ce_loss_weight = 1.,
 77  78    ignore_index = -1
 78  79    ):
 79  80    super().__init__()
@@ -129,6 +130,8 @@ def __init__(
129 130
130 131    # Q(continue|halt) for their adaptive computation time setup
131 132
    133  + self.act_binary_ce_loss_weight = act_binary_ce_loss_weight
    134  +
132 135    self.min_reasoning_steps = min_reasoning_steps
133 136    self.max_reasoning_steps = max_reasoning_steps
134 137
@@ -225,9 +228,9 @@ def evaluate_network_(
225 228
226 229    highest_hidden = hiddens[self.num_networks - 1]
227 230
228      - q_continue, q_halt = self.to_q_continue_halt(highest_hidden)
    231  + q_continue, q_halt = self.to_q_continue_halt(highest_hidden).sigmoid()
229 232
230      - should_continue_ = q_halt > q_continue
    233  + should_continue = q_halt > q_continue
231 234
232 235    # 1-step gradient learning
233 236
pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
1 1    [project]
2 2    name = "HRM"
3    - version = "0.0.2"
  3  + version = "0.0.3"
4 4    description = "The proposal from a Singaporean AGI company"
5 5    authors = [
6 6    { name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments

Comments
 (0)