How to assign methods as Flax Module attributes? #995
Original question from @jacobaustin123: I've been running into issues with Flax when assigning methods as dataclass attributes. Here's a minimal example:

```python
import flax.linen as nn

class UNet(nn.Module):
    nonlinearity = nn.swish

    @nn.compact
    def __call__(self, x):
        return self.nonlinearity(x)
```

This raises an error.
Answer by @hexahedria: Dataclasses actually read the type annotations, so the code below works fine (even with `Any` as the annotation), but fails without any type annotation:

```python
from typing import Callable
import flax.linen as nn
class Foo(nn.Module):
    go: Callable = somefunc  # somefunc: whatever function you want to store
```

So, moral of the story: put type annotations on the things you want to be dataclass attributes.
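The same mechanism can be seen with the standard library `dataclasses` module alone, which Flax Modules build on. This is a minimal sketch under that assumption; `double`, `NoAnnotation`, `WithCallable`, `WithAny`, and `field_names` are hypothetical names chosen for illustration. In plain Python, an unannotated function on the class is not a field, so it is looked up as a method and `self` gets passed as an extra argument. The exact error Flax reports may differ.

```python
import dataclasses
from typing import Any, Callable


def double(x):
    """Hypothetical stand-in for an activation such as nn.swish."""
    return 2 * x


@dataclasses.dataclass
class NoAnnotation:
    fn = double  # no annotation: plain class attribute, NOT a dataclass field


@dataclasses.dataclass
class WithCallable:
    fn: Callable = double  # annotated: becomes a field with a default value


@dataclasses.dataclass
class WithAny:
    fn: Any = double  # any annotation works, even Any


def field_names(cls):
    """Return the names of the dataclass fields declared on cls."""
    return [f.name for f in dataclasses.fields(cls)]


if __name__ == "__main__":
    print(field_names(NoAnnotation))  # []
    print(field_names(WithCallable))  # ['fn']
    print(WithCallable().fn(3))  # 6: the instance attribute holds the plain function
    print(WithAny().fn(3))  # 6
    try:
        # Found on the class, so it is bound as a method: double(self, 3)
        NoAnnotation().fn(3)
    except TypeError as err:
        print("TypeError:", err)
```

Because the annotated version is a real field, it can also be overridden per instance, e.g. `WithCallable(fn=abs).fn(-3)` returns `3`. With Flax, the analogous move would be annotating `nonlinearity` in `UNet` and then constructing `UNet(nonlinearity=nn.relu)`.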