Skip to content

Make GraphNetworkLayer subclass from tf.Module#337

Merged
boomanaiden154 merged 4 commits into
mainfrom
users/boomanaiden154/make-graphnetworklayer-subclass-from-tfmodule
May 3, 2025
Merged

Make GraphNetworkLayer subclass from tf.Module#337
boomanaiden154 merged 4 commits into
mainfrom
users/boomanaiden154/make-graphnetworklayer-subclass-from-tfmodule

Conversation

@boomanaiden154

@boomanaiden154 boomanaiden154 commented Apr 26, 2025

Copy link
Copy Markdown
Collaborator

We cannot make the dataclass frozen when doing this, but this enables TF
to automatically find trainable variables within GraphNetworkLayer
objects which means we can get rid of the somewhat hacky
_get_trainable_variables function that subclasses were supposed to
override. This successfully trains models that would otherwise fail to
converge if none of the graph layers were trainable.

This closes #323.

@boomanaiden154 boomanaiden154 requested a review from ondrasej April 26, 2025 18:19
boomanaiden154 and others added 2 commits April 28, 2025 17:14
Created using spr 1.3.4

[skip ci]
Created using spr 1.3.4
boomanaiden154 added a commit to boomanaiden154/gematria that referenced this pull request Apr 30, 2025
We cannot make the dataclass frozen when doing this, but this enables TF
to automatically find trainable variables within GraphNetworkLayer
objects which means we can get rid of the somewhat hacky
_get_trainable_variables function that subclasses were supposed to
override. This successfully trains models that would otherwise fail to
converge if none of the graph layers were trainable.

This closes google#323.

Reviewers: ondrasej

Pull Request: google#337
@boomanaiden154 boomanaiden154 changed the base branch from users/boomanaiden154/main.make-graphnetworklayer-subclass-from-tfmodule to main May 3, 2025 02:54
@boomanaiden154 boomanaiden154 merged commit e1a90cf into main May 3, 2025
7 checks passed
@boomanaiden154 boomanaiden154 deleted the users/boomanaiden154/make-graphnetworklayer-subclass-from-tfmodule branch May 3, 2025 02:55
boomanaiden154 added a commit to boomanaiden154/gematria that referenced this pull request May 17, 2025
This patch provides a custom implementation of trainable_variables in
gnn_model_base. Theoretically this should have been made unnecessary by
\google#337, but the interanl version of Tensorflow refuses to recurse into
the modules inside of the GraphNetworkLayer classes. This patch fixes
that by just returning the values regardless.
boomanaiden154 added a commit to boomanaiden154/gematria that referenced this pull request May 17, 2025
This patch provides a custom implementation of trainable_variables in
gnn_model_base. Theoretically this should have been made unnecessary by
\google#337, but the interanl version of Tensorflow refuses to recurse into
the modules inside of the GraphNetworkLayer classes. This patch fixes
that by just returning the values regardless.

Pull Request: google#341
boomanaiden154 added a commit that referenced this pull request May 19, 2025
This patch provides a custom implementation of trainable_variables in
gnn_model_base. Theoretically this should have been made unnecessary by
\#337, but the interanl version of Tensorflow refuses to recurse into
the modules inside of the GraphNetworkLayer classes. This patch fixes
that by just returning the values regardless.

Reviewers: orodley, virajbshah, ondrasej

Reviewed By: ondrasej

Pull Request: #341
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[TF2] Make GraphNetworkLayer subclass from tf.Module

2 participants