You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
Copy file name to clipboard. Expand all lines: HRM/hrm.py
+17 -8. Lines changed: 17 additions & 8 deletions
Original file line number
Diff line number
Diff line change
@@ -1,5 +1,6 @@
1
1
from __future__ import annotations
2
2
from contextlib import nullcontext
3
+
from random import randrange, random
3
4
4
5
import torch
5
6
import torch.nn.functional as F
@@ -25,6 +26,9 @@ def default(v, d):
25
26
def last(arr):
    """Return the final element of the indexable `arr` (i.e. `arr[-1]`).

    An empty sequence raises whatever `arr[-1]` raises for that type
    (`IndexError` for the builtin sequences).
    """
    return arr[-1]
27
28
29
+
def satisfy_prob(prob):
    """Return True when a fresh `random()` draw is strictly less than `prob`.

    `random` is the stdlib function imported at the top of the file; it
    yields a float in [0.0, 1.0), so `prob <= 0` never passes and
    `prob >= 1` always passes.
    """
    return random() < prob
31
+
28
32
def divisible_by(num, den):
    """Return True when `num` leaves no remainder on division by `den`.

    `den == 0` is not guarded; the modulo raises `ZeroDivisionError`
    for integer operands.
    """
    return (num % den) == 0
30
34
@@ -72,7 +76,7 @@ def __init__(
72
76
num_tokens,
73
77
reasoning_steps=2, # N in the paper - the number of forward evals for the last network (highest hierarchy) above
74
78
relative_period: int|tuple[int, ...] =2, # the relative period for each network evaluation call to the one just previous - in the paper, they do 2 networks with a period of 2
75
-
min_reasoning_steps=1,
79
+
min_reasoning_steps_epsilon_prob=0.5, # they stochastically choose the minimum segment from 2 .. max with this probability, and 1 step the rest of the time
0 commit comments