-
I am trying to implement the RetNet paper and have been having some trouble with the logic of having these various different modes, i.e. parallel, chunkwise, and recurrent. At the moment I have implemented retention with all three modes, and each works on its own, but I am having trouble figuring out how to efficiently use the modes in tandem, e.g. running the recurrence at test time while also being able to run the parallel mode at train time. Are there any suggestions for this? Perhaps there could be a way to turn the scan on and off? Maybe transferring parameters between two separately initialized instances of the same module? I'm not really sure and I'm kind of stumped.
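To make this more concrete, here is a rough sketch of the kind of thing I have in mind (this is a simplified stand-in, not my actual retention code: the decay/mask terms are left out, the shapes assume a single unbatched sequence, and the projection names are just placeholders). Both branches go through the same named submodules, so ideally I could train with mode="parallel" and then decode step by step with mode="recurrent" using the same parameters:

```python
import jax.numpy as jnp
import flax.linen as nn


class Retention(nn.Module):
    dim: int

    @nn.compact
    def __call__(self, x, mode="parallel", state=None):
        # The same Dense submodules are used by both branches, so the
        # parameter pytree is identical whichever mode is called.
        q = nn.Dense(self.dim, name="q_proj")(x)
        k = nn.Dense(self.dim, name="k_proj")(x)
        v = nn.Dense(self.dim, name="v_proj")(x)

        if mode == "parallel":
            # full-sequence form used at train time; x has shape (T, dim)
            # (the causal decay mask is omitted in this sketch)
            scores = q @ jnp.swapaxes(k, -1, -2)
            return scores @ v, None
        elif mode == "recurrent":
            # single-step form used at test time; x has shape (1, dim) and
            # `state` carries the running (dim, dim) retention state
            state = jnp.zeros((self.dim, self.dim)) if state is None else state
            state = state + k.T @ v  # per-step decay omitted in this sketch
            return q @ state, state
        else:
            raise ValueError(f"unknown mode: {mode}")
```

I realize that under jax.jit the mode argument would have to stay a Python constant (marked static or baked in with functools.partial), which is part of what I'm unsure about.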
-
The best thing I could find was this example:

```python
class MLP(nn.Module):
    num_layers: int
    depth: int
    use_scan: bool

    @nn.compact
    def __call__(self, x):
        if self.use_scan:
            # scan a single DotReluDot over the layer stack
            # (DotReluDot is a submodule defined elsewhere in the example;
            #  it returns an (output, None) tuple)
            x, _ = nn.scan(DotReluDot, length=self.num_layers,
                           variable_axes={"params": 0},
                           split_rngs={"params": True},
                           metadata_params={nn.PARTITION_NAME: None}
                           )(self.depth)(x)
        else:
            # unrolled Python loop over separate DotReluDot instances
            for i in range(self.num_layers):
                x, _ = DotReluDot(self.depth)(x)
        return x
```

However, this isn't quite what I'm looking for. It would be great to be able to pass the "mode" I would like to be in via the forward pass. I'm not totally sure if this is correct, but let me modify the above code to what I wish could be performed:

```python
class MLP(nn.Module):
    num_layers: int
    depth: int

    @nn.compact
    def __call__(self, x, use_scan=True):
        if use_scan:
            x, _ = nn.scan(DotReluDot, length=self.num_layers,
                           variable_axes={"params": 0},
                           split_rngs={"params": True},
                           metadata_params={nn.PARTITION_NAME: None}
                           )(self.depth)(x)
        else:
            for i in range(self.num_layers):
                x, _ = DotReluDot(self.depth)(x)
        return x
```
-
Hey @BeeGass, the pattern you show here is what is most commonly used, e.g. in MaxText. Flax doesn't have an option to run the scan logic eagerly.
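For what it's worth, here is a rough usage sketch of the call-time flag from your second snippet (not an official Flax recipe; it assumes the MLP and DotReluDot definitions above are in scope). Under jax.jit the flag has to stay a Python constant, e.g. baked in with functools.partial, and as far as I can tell the value also has to match between init and apply, since the scanned and unrolled branches lay out their parameters differently:

```python
from functools import partial

import jax
import jax.numpy as jnp

model = MLP(num_layers=4, depth=128)
x = jnp.ones((2, 128))

# initialize with the scan branch; the params get the stacked, scanned layout
params = model.init(jax.random.PRNGKey(0), x, use_scan=True)

# fix the flag at jit time instead of passing it as a traced argument
apply_scan = jax.jit(partial(model.apply, use_scan=True))
y = apply_scan(params, x)
```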