|
3 | 3 |
|
4 | 4 | import torch |
5 | 5 | import torch.nn.functional as F |
6 | | -from torch import nn, Tensor, tensor, is_tensor, cat, stack |
| 6 | +from torch import Tensor, tensor, is_tensor, cat, stack |
7 | 7 | from torch.nn import Embedding, Linear, Module, ModuleList |
8 | 8 | from torch.utils._pytree import tree_map |
9 | 9 |
|
@@ -38,9 +38,9 @@ def __init__( |
38 | 38 | super().__init__() |
39 | 39 | self.num_hiddens_to_concat = num_hiddens_to_concat |
40 | 40 |
|
41 | | - self.norms = ModuleList([nn.RMSNorm(dim) for _ in range(num_hiddens_to_concat)]) |
| 41 | + self.norms = ModuleList([RMSNorm(dim) for _ in range(num_hiddens_to_concat)]) |
42 | 42 |
|
43 | | - self.to_combined = nn.Linear(dim * self.num_hiddens_to_concat, dim, bias = False) |
| 43 | + self.to_combined = Linear(dim * self.num_hiddens_to_concat, dim, bias = False) |
44 | 44 |
|
45 | 45 | def forward( |
46 | 46 | self, |
@@ -132,9 +132,12 @@ def forward( |
132 | 132 | *, |
133 | 133 | labels = None, |
134 | 134 | detach_hiddens = True, |
135 | | - one_step_grad = True |
| 135 | + one_step_grad = True, |
| 136 | + reasoning_steps = None |
136 | 137 | ): |
137 | 138 |
|
| 139 | + reasoning_steps = default(reasoning_steps, self.reasoning_steps) |
| 140 | + |
138 | 141 | if detach_hiddens: |
139 | 142 | hiddens = tree_map_tensor(hiddens, lambda t: t.detach()) |
140 | 143 |
|
@@ -184,7 +187,7 @@ def evaluate_network_( |
184 | 187 | context = torch.no_grad if one_step_grad else nullcontext |
185 | 188 |
|
186 | 189 | with context(): |
187 | | - for index in range(self.reasoning_steps * self.lowest_steps_per_reasoning_step - 1): |
| 190 | + for index in range(reasoning_steps * self.lowest_steps_per_reasoning_step - 1): |
188 | 191 | iteration = index + 1 |
189 | 192 |
|
190 | 193 | for network_index, (network, hidden_combine, evaluate_network_at) in enumerate(zip(self.networks, self.hidden_combiners, self.evaluate_networks_at)): |
|
0 commit comments